diff --git a/.idea/workspace.xml b/.idea/workspace.xml new file mode 100644 index 00000000..359cff88 --- /dev/null +++ b/.idea/workspace.xml @@ -0,0 +1,237 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + { + "associatedIndex": 6 +} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1772062719090 + + + + + + \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 8de6eb4c..ec77f290 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -746,7 +746,7 @@ dependencies = [ "futures-core", "futures-sink", "http", - "indexmap 2.2.6", + "indexmap 2.13.0", "slab", "tokio", "tokio-util", @@ -761,9 +761,9 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" -version = "0.14.5" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" [[package]] name = "heck" @@ -972,13 +972,14 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.6" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", - "hashbrown 0.14.5", + "hashbrown 0.16.1", "serde", + "serde_core", ] [[package]] @@ -1041,9 +1042,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.81" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec48937a97411dcb524a265206ccd4c90bb711fca92b2792c407f268825b9305" +checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" dependencies = [ "once_cell", "wasm-bindgen", @@ -1188,6 +1189,7 @@ dependencies = [ "chrono", "clap", "env_logger", + "futures", "log", "moq-native-ietf", "moq-transport", @@ -1228,6 +1230,7 @@ dependencies = [ "bytes", "clap", "env_logger", + "futures", "log", "moq-catalog", "moq-native-ietf", @@ -1580,6 +1583,7 @@ version = "0.11.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" dependencies = [ + "aws-lc-rs", "bytes", "fastbloom", "getrandom 0.3.3", @@ -2125,7 +2129,7 @@ version = "1.0.145" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.13.0", "itoa", "memchr", "ryu", @@ -2165,7 +2169,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.2.6", + "indexmap 2.13.0", "schemars 0.9.0", "schemars 1.0.4", "serde", @@ -2187,6 +2191,17 @@ dependencies = [ "syn", ] +[[package]] +name = "sfv" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d471eaefb14f4b30032525bdb124b36e55ba9cb1292080e06f1a236cd10fe87" +dependencies = [ + "base64", + "indexmap 2.13.0", + "ref-cast", +] + [[package]] name = "sha1_smol" version = "1.0.0" @@ -2670,9 +2685,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.104" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1da10c01ae9f1ae40cbfac0bac3b1e724b320abfcf52229f80b547c0d250e2d" +checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" dependencies = [ "cfg-if", "once_cell", @@ -2681,20 +2696,6 @@ dependencies = [ "wasm-bindgen-shared", ] -[[package]] -name = "wasm-bindgen-backend" -version = "0.2.104" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "671c9a5a66f49d8a47345ab942e2cb93c7d1d0339065d4f8139c486121b43b19" -dependencies = [ - "bumpalo", - "log", - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-shared", -] - [[package]] name = "wasm-bindgen-futures" version = "0.4.42" @@ -2709,9 +2710,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.104" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ca60477e4c59f5f2986c50191cd972e3a50d8a95603bc9434501cf156a9a119" +checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2719,31 +2720,43 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.104" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f07d2f20d4da7b26400c9f4a0511e6e0345b040694e8a75bd41d578fa4421d7" +checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" dependencies = [ + "bumpalo", "proc-macro2", "quote", "syn", - "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.104" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bad67dc8b2a1a6e5448428adec4c3e84c43e561d8c9ee8a9e5aabeb193ec41d1" +checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" dependencies = [ "unicode-ident", ] +[[package]] +name = "web-streams" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48465a648c14f53f6d8319b95bc336a44627f6aa6bd94270463777af8ed65deb" +dependencies = [ + "thiserror 2.0.17", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "854ba17bb104abfb26ba36da9729addc7ce7f06f5c0f90f3c391f8461cca21f9" dependencies = [ "js-sys", "wasm-bindgen", @@ -2761,56 +2774,73 @@ dependencies = [ [[package]] name = "web-transport" -version = "0.3.0" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5793aee9b4cf993212042c6b1656d877de9ad32b9eb3281d7bc95f4dce3f6591" +checksum = "23c3f78eca5afa10eb7b8ab64b4e5e521a006f0cbd88de09e44d55ef37e8855a" dependencies = [ "bytes", - "thiserror 1.0.61", + "thiserror 2.0.17", + "url", "web-transport-quinn", "web-transport-wasm", ] [[package]] name = "web-transport-proto" -version = "0.2.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df0922f754c890ceb9741c00a0f5c730aaa4b52fe8772934a0ad19a03daee0ca" +checksum = "0225d295c8ac00a2e9a498aefeaf3f3c6186da12a251c938189b15b82ea22808" dependencies = [ "bytes", "http", - "thiserror 1.0.61", + "sfv", + "thiserror 2.0.17", + "tokio", "url", ] [[package]] name = "web-transport-quinn" -version = "0.3.0" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d248fb83873166e1fce7e91370deb15bd5213cf4352242e32ccd4abc8aeb2cef" +checksum = "82e77c81fe4cf56c1049e07c6ed9c00862a967010fe9da4f5e02dc7f4d71fdac" dependencies = [ "bytes", "futures", "http", - "log", "quinn", - "quinn-proto", - "thiserror 1.0.61", + "rustls 0.23.31", + "rustls-native-certs 0.8.1", + "thiserror 2.0.17", "tokio", + "tracing", "url", "web-transport-proto", + "web-transport-trait", +] + +[[package]] +name = "web-transport-trait" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb67841c4a481ca3c1412ee4c9f463987401991e1ddc000903df2124f3dc85e9" +dependencies = [ + "bytes", ] [[package]] name = "web-transport-wasm" -version = "0.1.0" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64be28348e18cb1f44e4c8733dc2bd9520d782be840b2b978724dfd1b1bdefa3" +checksum = "6816176def6e8df1636c8fc2c401f37add41ccad1518705e209d9a7ada3d144c" dependencies = [ "bytes", "js-sys", + "thiserror 2.0.17", + "url", "wasm-bindgen", "wasm-bindgen-futures", + "web-streams", "web-sys", ] diff --git a/Cargo.toml b/Cargo.toml index c903feaa..93663988 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ members = [ resolver = "2" [workspace.dependencies] -web-transport = "0.3" +web-transport = "0.10" env_logger = "0.11" log = { version = "0.4", features = ["std"] } diff --git a/moq-clock-ietf/Cargo.toml b/moq-clock-ietf/Cargo.toml index b717051b..854687a7 100644 --- a/moq-clock-ietf/Cargo.toml +++ b/moq-clock-ietf/Cargo.toml @@ -22,6 +22,7 @@ url = "2" # Async stuff tokio = { version = "1", features = ["full"] } +futures = "0.3" # CLI, logging, error handling clap = { version = "4", features = ["derive"] } diff --git a/moq-clock-ietf/src/clock.rs b/moq-clock-ietf/src/clock.rs index a96863ee..ada2bbd6 100644 --- a/moq-clock-ietf/src/clock.rs +++ b/moq-clock-ietf/src/clock.rs @@ -45,6 +45,7 @@ impl Publisher { group_id: next_group_id as u64, subgroup_id: 0, priority: 0, + header_type: None, }) .context("failed to create minute segment")?; @@ -66,6 +67,7 @@ impl Publisher { priority: 127, payload: time_str.clone().into_bytes().into(), extension_headers: Default::default(), + status: None, }) .context("failed to write datagram")?; diff --git a/moq-clock-ietf/src/main.rs b/moq-clock-ietf/src/main.rs index 7d0f6951..eab274b7 100644 --- a/moq-clock-ietf/src/main.rs +++ b/moq-clock-ietf/src/main.rs @@ -1,6 +1,7 @@ use moq_native_ietf::quic; use anyhow::Context; +use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; mod cli; mod clock; @@ -10,11 +11,36 @@ use cli::Cli; use moq_transport::{ coding::TrackNamespace, - serve, - session::{Publisher, Subscriber}, + serve::{self, TracksReader}, + session::{Publisher, SessionError, Subscriber}, }; -/// The main entry point for the MoQ Clock IETF example. +async fn serve_subscriptions( + mut publisher: Publisher, + tracks: TracksReader, +) -> Result<(), SessionError> { + let mut tasks: FuturesUnordered> = + FuturesUnordered::new(); + + loop { + tokio::select! { + Some(subscribed) = publisher.subscribed() => { + let info = subscribed.info.clone(); + let tracks = tracks.clone(); + log::info!("serving subscribe: {:?}", info); + + tasks.push(async move { + if let Err(err) = Publisher::serve_subscribe(subscribed, tracks).await { + log::warn!("failed serving subscribe: {:?}, error: {}", info, err); + } + }.boxed()); + } + _ = tasks.next(), if !tasks.is_empty() => {} + else => return Ok(()), + } + } +} + #[tokio::main] async fn main() -> anyhow::Result<()> { env_logger::init(); @@ -59,10 +85,16 @@ async fn main() -> anyhow::Result<()> { let track_writer = tracks_writer.create(&config.track).unwrap(); let clock_publisher = clock::Publisher::new_datagram(track_writer.datagrams()?); + let publish_ns = publisher + .publish_namespace(tracks_reader.namespace.clone()) + .await + .context("failed to register namespace")?; + tokio::select! { res = session.run() => res.context("session error")?, res = clock_publisher.run() => res.context("clock error")?, - res = publisher.announce(tracks_reader) => res.context("failed to serve tracks")?, + res = serve_subscriptions(publisher, tracks_reader) => res.context("failed to serve tracks")?, + res = publish_ns.closed() => res.context("namespace closed")?, } } else { log::info!("publishing clock via streams"); @@ -75,10 +107,16 @@ async fn main() -> anyhow::Result<()> { let track_writer = tracks_writer.create(&config.track).unwrap(); let clock_publisher = clock::Publisher::new(track_writer.subgroups()?); + let publish_ns = publisher + .publish_namespace(tracks_reader.namespace.clone()) + .await + .context("failed to register namespace")?; + tokio::select! { res = session.run() => res.context("session error")?, res = clock_publisher.run() => res.context("clock error")?, - res = publisher.announce(tracks_reader) => res.context("failed to serve tracks")?, + res = serve_subscriptions(publisher, tracks_reader) => res.context("failed to serve tracks")?, + res = publish_ns.closed() => res.context("namespace closed")?, } } } else { diff --git a/moq-native-ietf/Cargo.toml b/moq-native-ietf/Cargo.toml index adc5b302..41eeb1c9 100644 --- a/moq-native-ietf/Cargo.toml +++ b/moq-native-ietf/Cargo.toml @@ -14,7 +14,7 @@ categories = ["multimedia", "network-programming", "web-programming"] [dependencies] moq-transport = { path = "../moq-transport", version = "0.12" } web-transport = { workspace = true } -web-transport-quinn = "0.3" +web-transport-quinn = { version = "0.11", default-features = false, features = ["ring"] } rustls = { version = "0.23", features = ["ring"] } rustls-pemfile = "2" diff --git a/moq-native-ietf/src/quic.rs b/moq-native-ietf/src/quic.rs index bed04d83..e5d6ac8d 100644 --- a/moq-native-ietf/src/quic.rs +++ b/moq-native-ietf/src/quic.rs @@ -160,7 +160,7 @@ impl Endpoint { if let Some(mut config) = config.tls.server { config.alpn_protocols = vec![ - web_transport_quinn::ALPN.to_vec(), + web_transport_quinn::ALPN.as_bytes().to_vec(), moq_transport::setup::ALPN.to_vec(), ]; config.key_log = Arc::new(rustls::KeyLogFile::new()); @@ -305,22 +305,24 @@ impl Server { server_name, ); - let session = match alpn.as_bytes() { - web_transport_quinn::ALPN => { - // Wait for the CONNECT request. - let request = web_transport_quinn::accept(conn) - .await - .context("failed to receive WebTransport request")?; - - // Accept the CONNECT request. - request - .ok() - .await - .context("failed to respond to WebTransport request")? - } - // A bit of a hack to pretend like we're a WebTransport session - moq_transport::setup::ALPN => conn.into(), - _ => anyhow::bail!("unsupported ALPN: {}", alpn), + let alpn_bytes = alpn.as_bytes(); + let session = if alpn_bytes == web_transport_quinn::ALPN.as_bytes() { + // Wait for the WebTransport CONNECT request (includes H3 SETTINGS exchange). + let request = web_transport_quinn::Request::accept(conn) + .await + .context("failed to receive WebTransport request")?; + + // Accept the CONNECT request. + request + .ok() + .await + .context("failed to respond to WebTransport request")? + } else if alpn_bytes == moq_transport::setup::ALPN { + // Raw QUIC mode — create a session with no H3 framing. + let request = url::Url::parse("moqt://localhost").unwrap(); + web_transport_quinn::Session::raw(conn, request, web_transport_quinn::proto::ConnectResponse::default()) + } else { + anyhow::bail!("unsupported ALPN: {}", alpn) }; Ok((session.into(), connection_id_hex)) @@ -373,7 +375,7 @@ impl Client { // TODO support connecting to both ALPNs at the same time config.alpn_protocols = vec![match url.scheme() { - "https" => web_transport_quinn::ALPN.to_vec(), + "https" => web_transport_quinn::ALPN.as_bytes().to_vec(), "moqt" => moq_transport::setup::ALPN.to_vec(), _ => anyhow::bail!("url scheme must be 'https' or 'moqt'"), }]; @@ -426,8 +428,15 @@ impl Client { .to_string(); let session = match url.scheme() { - "https" => web_transport_quinn::connect_with(connection, url).await?, - "moqt" => connection.into(), + "https" => { + // Build a ConnectRequest with the MoQT version as the WebTransport subprotocol. + // Per draft-15+, version negotiation uses ALPN (raw QUIC) or + // wt-available-protocols (WebTransport) instead of CLIENT_SETUP versions. + let request = web_transport_quinn::proto::ConnectRequest::new(url.clone()) + .with_protocol(std::str::from_utf8(moq_transport::setup::ALPN).unwrap()); + web_transport_quinn::Session::connect(connection, request).await? + } + "moqt" => web_transport_quinn::Session::raw(connection, url.clone(), web_transport_quinn::proto::ConnectResponse::default()), _ => unreachable!(), }; diff --git a/moq-pub/Cargo.toml b/moq-pub/Cargo.toml index 13085ca3..d37ebbaf 100644 --- a/moq-pub/Cargo.toml +++ b/moq-pub/Cargo.toml @@ -23,6 +23,7 @@ bytes = "1" # Async stuff tokio = { version = "1", features = ["full"] } +futures = "0.3" # CLI, logging, error handling clap = { version = "4", features = ["derive"] } diff --git a/moq-pub/src/main.rs b/moq-pub/src/main.rs index cd9350fc..c259b440 100644 --- a/moq-pub/src/main.rs +++ b/moq-pub/src/main.rs @@ -4,11 +4,16 @@ use url::Url; use anyhow::Context; use clap::Parser; +use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use tokio::io::AsyncReadExt; use moq_native_ietf::quic; use moq_pub::Media; -use moq_transport::{coding::TrackNamespace, serve, session::Publisher}; +use moq_transport::{ + coding::TrackNamespace, + serve::{self, TracksReader}, + session::{Publisher, SessionError}, +}; #[derive(Parser, Clone)] pub struct Cli { @@ -39,6 +44,32 @@ pub struct Cli { pub tls: moq_native_ietf::tls::Args, } +async fn serve_subscriptions( + mut publisher: Publisher, + tracks: TracksReader, +) -> Result<(), SessionError> { + let mut tasks: FuturesUnordered> = + FuturesUnordered::new(); + + loop { + tokio::select! { + Some(subscribed) = publisher.subscribed() => { + let info = subscribed.info.clone(); + let tracks = tracks.clone(); + log::info!("serving subscribe: {:?}", info); + + tasks.push(async move { + if let Err(err) = Publisher::serve_subscribe(subscribed, tracks).await { + log::warn!("failed serving subscribe: {:?}, error: {}", info, err); + } + }.boxed()); + } + _ = tasks.next(), if !tasks.is_empty() => {} + else => return Ok(()), + } + } +} + #[tokio::main] async fn main() -> anyhow::Result<()> { env_logger::init(); @@ -71,16 +102,25 @@ async fn main() -> anyhow::Result<()> { connection_id ); - let (session, mut publisher) = Publisher::connect(session) + let (session, publisher) = Publisher::connect(session) .await .context("failed to create MoQ Transport publisher")?; + let namespace = reader.namespace.clone(); + + let publish_ns = publisher + .clone() + .publish_namespace(namespace) + .await + .context("failed to register namespace")?; + + log::info!("namespace registered, starting media and subscription handling"); + tokio::select! { res = session.run() => res.context("session error")?, - res = run_media(media) => { - res.context("media error")? - }, - res = publisher.announce(reader) => res.context("publisher error")?, + res = run_media(media) => res.context("media error")?, + res = serve_subscriptions(publisher, reader) => res.context("publisher error")?, + res = publish_ns.closed() => res.context("publisher error")?, } Ok(()) diff --git a/moq-pub/src/media.rs b/moq-pub/src/media.rs index b46f5473..f1952781 100644 --- a/moq-pub/src/media.rs +++ b/moq-pub/src/media.rs @@ -384,7 +384,12 @@ impl Track { } pub fn end_group(&mut self) { - self.current = None; + // Send EndOfGroup marker before dropping the writer + if let Some(mut current) = self.current.take() { + if let Err(e) = current.end_of_group() { + log::warn!("failed to send EndOfGroup marker: {}", e); + } + } } } diff --git a/moq-relay-ietf/src/consumer.rs b/moq-relay-ietf/src/consumer.rs index 8d636912..d0641740 100644 --- a/moq-relay-ietf/src/consumer.rs +++ b/moq-relay-ietf/src/consumer.rs @@ -3,11 +3,13 @@ use std::sync::Arc; use anyhow::Context; use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use moq_transport::{ - serve::Tracks, - session::{Announced, SessionError, Subscriber}, + coding::KeyValuePairs, + message::PublishOk, + serve::{ServeError, Tracks}, + session::{PublishNamespaceReceived, PublishReceived, SessionError, Subscriber}, }; -use crate::{Coordinator, Locals, Producer}; +use crate::{Coordinator, Locals, Producer, SubscriberRegistry}; /// Consumer of tracks from a remote Publisher #[derive(Clone)] @@ -16,6 +18,8 @@ pub struct Consumer { locals: Locals, coordinator: Arc, forward: Option, // Forward all announcements to this subscriber + subscriber_registry: Option, + session_id: u64, } impl Consumer { @@ -30,28 +34,63 @@ impl Consumer { locals, coordinator, forward, + subscriber_registry: None, + session_id: 0, } } - /// Run the consumer to serve announce requests. - pub async fn run(mut self) -> Result<(), SessionError> { - let mut tasks = FuturesUnordered::new(); + /// Creates a consumer with a subscriber registry for PUBLISH notifications. + pub fn with_registry( + subscriber: Subscriber, + locals: Locals, + coordinator: Arc, + forward: Option, + subscriber_registry: SubscriberRegistry, + session_id: u64, + ) -> Self { + Self { + subscriber, + locals, + coordinator, + forward, + subscriber_registry: Some(subscriber_registry), + session_id, + } + } + + /// Run the consumer to serve announce requests and track-level publish messages. + pub async fn run(self) -> Result<(), SessionError> { + let mut tasks: FuturesUnordered> = + FuturesUnordered::new(); loop { + let mut subscriber_ns = self.subscriber.clone(); + let mut subscriber_publish = self.subscriber.clone(); + tokio::select! { - // Handle a new announce request - Some(announce) = self.subscriber.announced() => { + Some(publish_ns) = subscriber_ns.publish_ns_recvd() => { let this = self.clone(); tasks.push(async move { - let info = announce.clone(); - log::info!("serving announce: {:?}", info); + let info = publish_ns.clone(); + log::info!("serving publish_namespace: {:?}", info); - // Serve the announce request - if let Err(err) = this.serve(announce).await { - log::warn!("failed serving announce: {:?}, error: {}", info, err) + if let Err(err) = this.serve_publish_namespace(publish_ns).await { + log::warn!("failed serving publish_namespace: {:?}, error: {}", info, err) } - }); + }.boxed()); + }, + Some(publish) = subscriber_publish.publish_received() => { + let this = self.clone(); + + tasks.push(async move { + let info = publish.info.clone(); + log::info!("serving publish (track-level): {:?}", info); + + if let Err(err) = this.serve_publish(publish).await { + log::warn!("failed serving publish: {:?}, error: {}", info, err) + } + }.boxed()); }, _ = tasks.next(), if !tasks.is_empty() => {}, else => return Ok(()), @@ -59,12 +98,14 @@ impl Consumer { } } - /// Serve an announce request. - async fn serve(mut self, mut announce: Announced) -> Result<(), anyhow::Error> { - let mut tasks = FuturesUnordered::new(); + async fn serve_publish_namespace( + mut self, + mut publish_ns: PublishNamespaceReceived, + ) -> Result<(), anyhow::Error> { + let mut tasks: FuturesUnordered>> = + FuturesUnordered::new(); - // Produce the tracks for this announce and return the reader - let (_, mut request, reader) = Tracks::new(announce.namespace.clone()).produce(); + let (writer, mut request, reader) = Tracks::new(publish_ns.namespace.clone()).produce(); // NOTE(mpandit): once the track is pulled from origin, internally it will be relayed // from this metal only, because now coordinator will have entry for the namespace. @@ -78,30 +119,54 @@ impl Consumer { .await?; // Register the local tracks, unregister on drop - let _register = self.locals.register(reader.clone()).await?; - - // Accept the announce with an OK response - announce.ok()?; - - // Forward the announce, if needed - if let Some(mut forward) = self.forward { - tasks.push( - async move { - log::info!("forwarding announce: {:?}", reader.info); - forward - .announce(reader) - .await - .context("failed forwarding announce") - } - .boxed(), - ); + let _register = self.locals.register(reader.clone(), writer).await?; + + publish_ns.ok()?; + + // Notify subscriber registry of the new PUBLISH_NAMESPACE + // This will trigger forwarding to matching SUBSCRIBE_NAMESPACE subscriptions + // Uses session_id for self-exclusion + if let Some(ref registry) = self.subscriber_registry { + let notified = registry.notify_publish_namespace(&publish_ns.namespace, self.session_id); + if notified > 0 { + log::info!( + "notified {} SUBSCRIBE_NAMESPACE subscriptions of PUBLISH_NAMESPACE {:?}", + notified, + publish_ns.namespace + ); + } } + // Forward publish_namespace upstream - keep handle alive in this scope + let _forwarded_publish_ns = if let Some(mut forward) = self.forward.clone() { + let reader_clone = reader.clone(); + log::info!("forwarding publish_namespace: {:?}", reader_clone.info); + match forward.publish_namespace(reader_clone).await { + Ok(publish_ns) => { + if let Err(e) = publish_ns.ok().await { + log::warn!("publish_namespace not accepted: {}", e); + None + } else { + log::info!( + "publish_namespace forwarded and accepted: {:?}", + publish_ns.info.namespace + ); + Some(publish_ns) + } + } + Err(e) => { + log::warn!("failed forwarding publish_namespace: {}", e); + None + } + } + } else { + None + }; + // Serve subscribe requests loop { tokio::select! { - // If the announce is closed, return the error - Err(err) = announce.closed() => return Err(err.into()), + Err(err) = publish_ns.closed() => return Err(err.into()), // Wait for the next subscriber and serve the track. Some(track) = request.next() => { @@ -125,4 +190,144 @@ impl Consumer { } } } + + async fn serve_publish(mut self, publish: PublishReceived) -> Result<(), anyhow::Error> { + let namespace = publish.info.track_namespace.clone(); + let track_name = publish.info.track_name.clone(); + let track_alias = publish.info.track_alias; + let initial_forward = publish.info.forward; + let publish_request_id = publish.info.id; + let track_extensions = publish.info.track_extensions.clone(); + + log::info!( + "received PUBLISH for track: {}/{} (forward={}, extensions={:?})", + namespace, + track_name, + initial_forward, + track_extensions + ); + + // Use auto-register variant to support SUBSCRIBE_NAMESPACE flow + // where PUBLISH can arrive without prior PUBLISH_NAMESPACE + let track_info = self + .locals + .get_or_create_track_info_auto_register(&namespace, &track_name); + + let writer = match track_info.publish_arrived() { + Ok(w) => w, + Err(ServeError::Uninterested) => { + log::info!( + "PUBLISH rejected: already subscribed to {}/{}", + namespace, + track_name + ); + publish.reject(ServeError::Uninterested.code(), "Already subscribed")?; + return Err(ServeError::Uninterested.into()); + } + Err(ServeError::Duplicate) => { + log::info!( + "PUBLISH rejected: already publishing {}/{}", + namespace, + track_name + ); + publish.reject(ServeError::Duplicate.code(), "Already publishing")?; + return Err(ServeError::Duplicate.into()); + } + Err(e) => { + publish.reject(e.code(), &e.to_string())?; + return Err(e.into()); + } + }; + + let reader = track_info.get_reader(); + + self.locals + .insert_track(&namespace, reader) + .context("failed to insert track into namespace")?; + + // Store publish info for forward state management + track_info.set_publish_info(publish_request_id, initial_forward); + + // Store track extensions for forwarding to subscribers + track_info.set_track_extensions(track_extensions); + + // Include forward=1 in PUBLISH_OK to tell publisher to start sending immediately + let mut params = KeyValuePairs::default(); + params.set_intvalue(0x10, 1); // Forward = 1 + + let msg = PublishOk { + id: publish.info.id, + params, + }; + + publish.accept(writer, msg)?; + + log::info!( + "PUBLISH accepted, track {}/{} now in Publishing state (forward={})", + namespace, + track_name, + initial_forward + ); + + // Notify subscriber registry of the new PUBLISH + // This will trigger forwarding to matching SUBSCRIBE_NAMESPACE subscriptions + // Uses session_id for self-exclusion (don't notify the same session that sent the PUBLISH) + if let Some(ref registry) = self.subscriber_registry { + let notified = registry.notify_publish(&namespace, &track_name, track_alias, self.session_id); + if notified > 0 { + log::info!( + "notified {} SUBSCRIBE_NAMESPACE subscriptions of PUBLISH {}/{}", + notified, + namespace, + track_name + ); + } + } + + // If forward=0 (paused), wait for subscribers to request forwarding + // When forward state changes to 1, send REQUEST_UPDATE to publisher + if !initial_forward { + let forward_rx = track_info.forward_receiver(); + if let Some(mut rx) = forward_rx { + log::info!( + "track {}/{} is paused (forward=0), waiting for subscriber to request forwarding", + namespace, + track_name + ); + + // Wait for forward state to change to true + loop { + rx.changed().await.ok(); + if *rx.borrow() { + // Forward state changed to true, send REQUEST_UPDATE + log::info!( + "subscriber arrived for paused track {}/{}, sending REQUEST_UPDATE with forward=1", + namespace, + track_name + ); + + let mut params = KeyValuePairs::default(); + params.set_intvalue(0x10, 1); // Forward = 1 + + let request_update = moq_transport::message::SubscribeUpdate { + id: self.subscriber.get_next_request_id(), + existing_request_id: publish_request_id, + params, + }; + + self.subscriber.send_message(request_update); + log::info!( + "sent REQUEST_UPDATE for track {}/{} (existing_request_id={})", + namespace, + track_name, + publish_request_id + ); + break; + } + } + } + } + + Ok(()) + } } diff --git a/moq-relay-ietf/src/lib.rs b/moq-relay-ietf/src/lib.rs index aac39326..11a9456f 100644 --- a/moq-relay-ietf/src/lib.rs +++ b/moq-relay-ietf/src/lib.rs @@ -36,6 +36,7 @@ mod producer; mod relay; mod remote; mod session; +mod subscriber_registry; mod web; pub use api::*; @@ -46,4 +47,5 @@ pub use producer::*; pub use relay::*; pub use remote::*; pub use session::*; +pub use subscriber_registry::*; pub use web::*; diff --git a/moq-relay-ietf/src/local.rs b/moq-relay-ietf/src/local.rs index 406e6650..e56211b3 100644 --- a/moq-relay-ietf/src/local.rs +++ b/moq-relay-ietf/src/local.rs @@ -1,17 +1,252 @@ use std::collections::hash_map; use std::collections::HashMap; - -use std::sync::{Arc, Mutex}; +use std::sync::atomic::{AtomicBool, AtomicU8, Ordering}; +use std::sync::{Arc, Mutex, OnceLock}; use moq_transport::{ coding::TrackNamespace, - serve::{ServeError, TracksReader}, + data::ExtensionHeaders, + serve::{ServeError, Track, TrackReader, TrackWriter, TracksReader, TracksWriter}, }; +use tokio::sync::watch; + +#[repr(u8)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum TrackState { + Pending = 0, + Subscribing = 1, + Subscribed = 2, + Publishing = 3, + Closed = 4, +} + +impl TrackState { + fn from_u8(v: u8) -> Self { + match v { + 0 => TrackState::Pending, + 1 => TrackState::Subscribing, + 2 => TrackState::Subscribed, + 3 => TrackState::Publishing, + _ => TrackState::Closed, + } + } +} + +pub struct TrackInfo { + pub namespace: TrackNamespace, + pub name: String, + + state: AtomicU8, + track_reader: OnceLock, + track_writer: Mutex>, + upstream_subscribe_sent: AtomicBool, + upstream_request_id: Mutex>, + + /// The PUBLISH request ID (set when publisher sends PUBLISH) + publish_request_id: Mutex>, + /// Forward state: true = forwarding, false = paused + /// Publisher watches this to know when to start/stop sending + forward_state: Mutex>>, + /// Receiver side for forward state changes + forward_receiver: Mutex>>, + /// Track extensions from the original PUBLISH message + track_extensions: Mutex>, +} + +impl TrackInfo { + pub fn new(namespace: TrackNamespace, name: String) -> Self { + Self { + namespace, + name, + state: AtomicU8::new(TrackState::Pending as u8), + track_reader: OnceLock::new(), + track_writer: Mutex::new(None), + upstream_subscribe_sent: AtomicBool::new(false), + upstream_request_id: Mutex::new(None), + publish_request_id: Mutex::new(None), + forward_state: Mutex::new(None), + forward_receiver: Mutex::new(None), + track_extensions: Mutex::new(None), + } + } + + pub fn get_reader(&self) -> TrackReader { + self.ensure_track_created(); + self.track_reader.get().unwrap().clone() + } + + pub fn should_subscribe_upstream(&self) -> bool { + let state = self.state(); + + if state == TrackState::Publishing { + return false; + } + + !self.upstream_subscribe_sent.swap(true, Ordering::SeqCst) + } + + pub fn mark_subscribe_sent(&self, request_id: u64) { + *self.upstream_request_id.lock().unwrap() = Some(request_id); + + let _ = self.state.compare_exchange( + TrackState::Pending as u8, + TrackState::Subscribing as u8, + Ordering::SeqCst, + Ordering::SeqCst, + ); + } + + pub fn subscribe_ok_received(&self) { + let _ = self.state.compare_exchange( + TrackState::Subscribing as u8, + TrackState::Subscribed as u8, + Ordering::SeqCst, + Ordering::SeqCst, + ); + } + + pub fn publish_arrived(&self) -> Result { + self.ensure_track_created(); + + let current_state = self.state(); + + if current_state == TrackState::Subscribed { + return Err(ServeError::Uninterested); + } + + if current_state == TrackState::Publishing { + return Err(ServeError::Duplicate); + } + + self.state + .store(TrackState::Publishing as u8, Ordering::SeqCst); + + self.track_writer + .lock() + .unwrap() + .take() + .ok_or(ServeError::Duplicate) + } + -/// Registry of local tracks + pub fn state(&self) -> TrackState { + TrackState::from_u8(self.state.load(Ordering::SeqCst)) + } + + pub fn is_publishing(&self) -> bool { + self.state() == TrackState::Publishing + } + + /// Set up forward state tracking when PUBLISH is received. + /// Returns the initial forward value that was set. + pub fn set_publish_info(&self, request_id: u64, initial_forward: bool) { + *self.publish_request_id.lock().unwrap() = Some(request_id); + + let (tx, rx) = watch::channel(initial_forward); + *self.forward_state.lock().unwrap() = Some(tx); + *self.forward_receiver.lock().unwrap() = Some(rx); + + log::debug!( + "set_publish_info: track {}/{} request_id={} initial_forward={}", + self.namespace, + self.name, + request_id, + initial_forward + ); + } + + /// Get the PUBLISH request ID + pub fn publish_request_id(&self) -> Option { + *self.publish_request_id.lock().unwrap() + } + + /// Get current forward state + pub fn is_forwarding(&self) -> bool { + self.forward_receiver + .lock() + .unwrap() + .as_ref() + .map(|rx| *rx.borrow()) + .unwrap_or(true) // Default to true if not set (legacy behavior) + } + + /// Request forwarding to start (called when a subscriber arrives). + /// Returns true if the state changed from false to true. + pub fn request_forward(&self) -> bool { + if let Some(tx) = self.forward_state.lock().unwrap().as_ref() { + let current = *tx.borrow(); + if !current { + let _ = tx.send(true); + log::info!( + "request_forward: track {}/{} forward state changed 0 -> 1", + self.namespace, + self.name + ); + return true; + } + } + false + } + + /// Get a receiver for forward state changes (for the publisher to watch) + pub fn forward_receiver(&self) -> Option> { + self.forward_receiver.lock().unwrap().clone() + } + + /// Set track extensions from the original PUBLISH message + pub fn set_track_extensions(&self, extensions: ExtensionHeaders) { + *self.track_extensions.lock().unwrap() = Some(extensions); + } + + /// Get track extensions (cloned) + pub fn track_extensions(&self) -> Option { + self.track_extensions.lock().unwrap().clone() + } + + pub fn take_writer_for_upstream(&self) -> Result { + self.ensure_track_created(); + + let current_state = self.state(); + + if current_state == TrackState::Publishing { + return Err(ServeError::Duplicate); + } + + if current_state == TrackState::Subscribing || current_state == TrackState::Subscribed { + return Err(ServeError::Duplicate); + } + + self.state + .store(TrackState::Subscribing as u8, Ordering::SeqCst); + + self.track_writer + .lock() + .unwrap() + .take() + .ok_or(ServeError::Duplicate) + } + + fn ensure_track_created(&self) { + self.track_reader.get_or_init(|| { + let (writer, reader) = Track::new(self.namespace.clone(), self.name.clone()).produce(); + *self.track_writer.lock().unwrap() = Some(writer); + reader + }); + } +} + +struct LocalsEntry { + /// reader and writer hold the readers and writers for a namespace + reader: TracksReader, + writer: TracksWriter, + /// tracks holds the individual tracks for a namespace + tracks: Mutex>>, +} + +/// Locals is a map of TrackNamespace to LocalsEntry #[derive(Clone)] pub struct Locals { - lookup: Arc>>, + lookup: Arc>>, } impl Default for Locals { @@ -20,7 +255,6 @@ impl Default for Locals { } } -/// Local tracks registry. impl Locals { pub fn new() -> Self { Self { @@ -28,13 +262,19 @@ impl Locals { } } - /// Register new local tracks. - pub async fn register(&mut self, tracks: TracksReader) -> anyhow::Result { - let namespace = tracks.namespace.clone(); + pub async fn register( + &mut self, + reader: TracksReader, + writer: TracksWriter, + ) -> anyhow::Result { + let namespace = reader.namespace.clone(); - // Insert the tracks(TracksReader) into the lookup table match self.lookup.lock().unwrap().entry(namespace.clone()) { - hash_map::Entry::Vacant(entry) => entry.insert(tracks), + hash_map::Entry::Vacant(entry) => entry.insert(LocalsEntry { + reader, + writer, + tracks: Mutex::new(HashMap::new()), + }), hash_map::Entry::Occupied(_) => return Err(ServeError::Duplicate.into()), }; @@ -46,17 +286,13 @@ impl Locals { Ok(registration) } - /// Retrieve local tracks by namespace using hierarchical prefix matching. - /// Returns the TracksReader for the longest matching namespace prefix. pub fn retrieve(&self, namespace: &TrackNamespace) -> Option { let lookup = self.lookup.lock().unwrap(); - // Find the longest matching prefix let mut best_match: Option = None; let mut best_len = 0; - for (registered_ns, tracks) in lookup.iter() { - // Check if registered_ns is a prefix of namespace + for (registered_ns, entry) in lookup.iter() { if namespace.fields.len() >= registered_ns.fields.len() { let is_prefix = registered_ns .fields @@ -65,7 +301,7 @@ impl Locals { .all(|(a, b)| a == b); if is_prefix && registered_ns.fields.len() > best_len { - best_match = Some(tracks.clone()); + best_match = Some(entry.reader.clone()); best_len = registered_ns.fields.len(); } } @@ -73,6 +309,233 @@ impl Locals { best_match } + + pub fn get_or_create_track_info( + &self, + namespace: &TrackNamespace, + track_name: &str, + ) -> Option> { + let lookup = self.lookup.lock().unwrap(); + + let entry = Self::find_best_match_entry(&lookup, namespace)?; + + // Use full namespace + track_name as key to avoid collisions + let track_key = format!("{}:{}", namespace, track_name); + + let mut tracks = entry.tracks.lock().unwrap(); + + let track_info = tracks + .entry(track_key) + .or_insert_with(|| Arc::new(TrackInfo::new(namespace.clone(), track_name.to_string()))) + .clone(); + + Some(track_info) + } + + /// Get or create track info, auto-registering the namespace if needed. + /// This supports the SUBSCRIBE_NAMESPACE flow where PUBLISH can arrive + /// without a prior PUBLISH_NAMESPACE. + pub fn get_or_create_track_info_auto_register( + &self, + namespace: &TrackNamespace, + track_name: &str, + ) -> Arc { + let mut lookup = self.lookup.lock().unwrap(); + + // Use full namespace + track_name as key to avoid collisions + // when different namespaces have the same track_name + let track_key = format!("{}:{}", namespace, track_name); + + // Check if there's an existing exact-match namespace entry that's stale + // and needs to be removed (this happens when publisher disconnects and reconnects) + let should_remove_namespace = if let Some(entry) = lookup.get(namespace) { + let tracks = entry.tracks.lock().unwrap(); + if let Some(existing) = tracks.get(&track_key) { + // Track exists and is in Publishing state but has no writer = stale + existing.state() == TrackState::Publishing + && existing.track_writer.lock().unwrap().is_none() + } else { + false + } + } else { + false + }; + + if should_remove_namespace { + log::info!( + "removing stale namespace entry {} (track {}/{} was Publishing with no writer)", + namespace, + namespace, + track_name + ); + lookup.remove(namespace); + } + + // First try to find an existing matching namespace entry + if let Some(entry) = Self::find_best_match_entry(&lookup, namespace) { + let mut tracks = entry.tracks.lock().unwrap(); + + return tracks + .entry(track_key.clone()) + .or_insert_with(|| { + Arc::new(TrackInfo::new(namespace.clone(), track_name.to_string())) + }) + .clone(); + } + + // No matching namespace found - auto-register for SUBSCRIBE_NAMESPACE flow + log::info!( + "auto-registering namespace {} for PUBLISH (no prior PUBLISH_NAMESPACE)", + namespace + ); + + let (writer, _request, reader) = + moq_transport::serve::Tracks::new(namespace.clone()).produce(); + + let entry = lookup.entry(namespace.clone()).or_insert(LocalsEntry { + reader, + writer, + tracks: Mutex::new(HashMap::new()), + }); + + let mut tracks = entry.tracks.lock().unwrap(); + tracks + .entry(track_key) + .or_insert_with(|| Arc::new(TrackInfo::new(namespace.clone(), track_name.to_string()))) + .clone() + } + + pub fn get_track_info( + &self, + namespace: &TrackNamespace, + track_name: &str, + ) -> Option> { + let lookup = self.lookup.lock().unwrap(); + + let entry = Self::find_best_match_entry(&lookup, namespace)?; + + // Use full namespace + track_name as key to match get_or_create_track_info + let track_key = format!("{}:{}", namespace, track_name); + let tracks = entry.tracks.lock().unwrap(); + tracks.get(&track_key).cloned() + } + + fn find_best_match_entry<'a>( + lookup: &'a HashMap, + namespace: &TrackNamespace, + ) -> Option<&'a LocalsEntry> { + let mut best_match: Option<&LocalsEntry> = None; + let mut best_len = 0; + + for (registered_ns, entry) in lookup.iter() { + if namespace.fields.len() >= registered_ns.fields.len() { + let is_prefix = registered_ns + .fields + .iter() + .zip(namespace.fields.iter()) + .all(|(a, b)| a == b); + + if is_prefix && registered_ns.fields.len() > best_len { + best_match = Some(entry); + best_len = registered_ns.fields.len(); + } + } + } + + best_match + } + + pub fn insert_track( + &self, + namespace: &TrackNamespace, + track_reader: TrackReader, + ) -> Option<()> { + let mut lookup = self.lookup.lock().unwrap(); + + if let Some(entry) = lookup.get_mut(namespace) { + entry.writer.insert(track_reader) + } else { + None + } + } + + pub fn subscribe_upstream(&self, track_info: Arc) -> Option { + let mut lookup = self.lookup.lock().unwrap(); + + let entry = lookup.get_mut(&track_info.namespace)?; + + let writer = track_info.take_writer_for_upstream().ok()?; + let reader = track_info.get_reader(); + + entry.reader.forward_upstream(writer)?; + + let namespace = track_info.namespace.clone(); + + let entry_mut = lookup + .iter_mut() + .find(|(ns, _)| { + namespace.fields.len() >= ns.fields.len() + && ns + .fields + .iter() + .zip(namespace.fields.iter()) + .all(|(a, b)| a == b) + }) + .map(|(_, e)| e)?; + + entry_mut.writer.insert(reader.clone()); + + Some(reader) + } + + pub fn matching_namespaces(&self, prefix: &TrackNamespace) -> Vec { + let lookup = self.lookup.lock().unwrap(); + + lookup + .keys() + .filter(|ns| { + if ns.fields.len() >= prefix.fields.len() { + prefix + .fields + .iter() + .zip(ns.fields.iter()) + .all(|(a, b)| a == b) + } else { + false + } + }) + .cloned() + .collect() + } + + /// Get all tracks in namespaces matching a prefix that are in Publishing state. + /// Returns (namespace, track_name, TrackInfo) tuples. + pub fn matching_tracks(&self, prefix: &TrackNamespace) -> Vec<(TrackNamespace, String, Arc)> { + let lookup = self.lookup.lock().unwrap(); + + let mut result = Vec::new(); + + for (ns, entry) in lookup.iter() { + // Check if namespace matches prefix + if ns.fields.len() >= prefix.fields.len() + && prefix + .fields + .iter() + .zip(ns.fields.iter()) + .all(|(a, b)| a == b) + { + // Get all tracks in this namespace that are publishing + let tracks = entry.tracks.lock().unwrap(); + for (key, track_info) in tracks.iter() { + if track_info.is_publishing() { + result.push((ns.clone(), track_info.name.clone(), track_info.clone())); + } + } + } + } + + result + } } pub struct Registration { @@ -80,7 +543,6 @@ pub struct Registration { namespace: TrackNamespace, } -/// Deregister local tracks on drop. impl Drop for Registration { fn drop(&mut self) { self.locals.lookup.lock().unwrap().remove(&self.namespace); diff --git a/moq-relay-ietf/src/producer.rs b/moq-relay-ietf/src/producer.rs index 23ea49f3..8149df46 100644 --- a/moq-relay-ietf/src/producer.rs +++ b/moq-relay-ietf/src/producer.rs @@ -1,10 +1,15 @@ use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use moq_transport::{ + coding::{KeyValuePairs, TrackNamespace}, + message, serve::{ServeError, TracksReader}, - session::{Publisher, SessionError, Subscribed, TrackStatusRequested}, + session::{ + PublishNamespace, Publisher, SessionError, SubscribeNamespaceReceived, Subscribed, + TrackStatusRequested, + }, }; -use crate::{Locals, RemotesConsumer}; +use crate::{Locals, RemotesConsumer, SubscriberRegistry}; /// Producer of tracks to a remote Subscriber #[derive(Clone)] @@ -12,6 +17,8 @@ pub struct Producer { publisher: Publisher, locals: Locals, remotes: Option, + subscriber_registry: Option, + session_id: u64, } impl Producer { @@ -20,15 +27,37 @@ impl Producer { publisher, locals, remotes, + subscriber_registry: None, + session_id: 0, } } - /// Announce new tracks to the remote server. - pub async fn announce(&mut self, tracks: TracksReader) -> Result<(), SessionError> { - self.publisher.announce(tracks).await + /// Creates a producer with a subscriber registry. + pub fn with_registry( + publisher: Publisher, + locals: Locals, + remotes: Option, + subscriber_registry: SubscriberRegistry, + session_id: u64, + ) -> Self { + Self { + publisher, + locals, + remotes, + subscriber_registry: Some(subscriber_registry), + session_id, + } + } + + pub async fn publish_namespace( + &mut self, + tracks: TracksReader, + ) -> Result { + self.publisher + .publish_namespace(tracks.namespace.clone()) + .await } - /// Run the producer to serve subscribe requests. pub async fn run(self) -> Result<(), SessionError> { //let mut tasks = FuturesUnordered::new(); let mut tasks: FuturesUnordered> = @@ -37,6 +66,7 @@ impl Producer { loop { let mut publisher_subscribed = self.publisher.clone(); let mut publisher_track_status = self.publisher.clone(); + let mut publisher_subscribe_ns = self.publisher.clone(); tokio::select! { // Handle a new subscribe request @@ -69,28 +99,72 @@ impl Producer { } }.boxed()) }, + Some(subscribe_ns) = publisher_subscribe_ns.subscribe_namespace_received() => { + let this = self.clone(); + + tasks.push(async move { + let info = subscribe_ns.info.clone(); + log::info!("serving subscribe_namespace: {:?}", info); + + if let Err(err) = this.serve_subscribe_namespace(subscribe_ns).await { + log::warn!("failed serving subscribe_namespace: {:?}, error: {}", info, err) + } + }.boxed()) + }, _= tasks.next(), if !tasks.is_empty() => {}, else => return Ok(()), }; } } - /// Serve a subscribe request. async fn serve_subscribe(self, subscribed: Subscribed) -> Result<(), anyhow::Error> { let namespace = subscribed.track_namespace.clone(); let track_name = subscribed.track_name.clone(); - // Check local tracks first, and serve from local if possible - if let Some(mut local) = self.locals.retrieve(&namespace) { - // Pass the full requested namespace, not the announced prefix - if let Some(track) = local.subscribe(namespace.clone(), &track_name) { - log::info!("serving subscribe from local: {:?}", track.info); - return Ok(subscribed.serve(track).await?); + if let Some(track_info) = self + .locals + .get_or_create_track_info(&namespace, &track_name) + { + if track_info.should_subscribe_upstream() { + log::info!( + "subscribe needs upstream request: {}/{}", + namespace, + track_name + ); + + if let Some(reader) = self.locals.subscribe_upstream(track_info.clone()) { + log::info!( + "forwarding subscribe upstream via TrackInfo: {}/{}", + namespace, + track_name + ); + return Ok(subscribed.serve(reader).await?); + } + } + + // If the track is in Publishing state and forward=0, request forwarding + // This will trigger the consumer to send REQUEST_UPDATE to the publisher + if track_info.is_publishing() && !track_info.is_forwarding() { + log::info!( + "subscriber arrived for paused track {}/{}, requesting forward", + namespace, + track_name + ); + track_info.request_forward(); } + + let reader = track_info.get_reader(); + log::info!( + "serving subscribe from local: {}/{} (state: {:?}, forwarding: {})", + namespace, + track_name, + track_info.state(), + track_info.is_forwarding() + ); + return Ok(subscribed.serve(reader).await?); } if let Some(remotes) = self.remotes { - // Check remote tracks second, and serve from remote if possible match remotes.route(&namespace).await { Ok(remote) => { if let Some(remote) = remote { @@ -105,7 +179,7 @@ impl Producer { } } } - // Track not found - close the subscription with not found error + let err = ServeError::not_found_ctx(format!( "track '{}/{}' not found in local or remote tracks", namespace, track_name @@ -114,7 +188,205 @@ impl Producer { Err(err.into()) } - /// Serve a track_status request. + async fn serve_subscribe_namespace( + mut self, + mut subscribe_ns: SubscribeNamespaceReceived, + ) -> Result<(), anyhow::Error> { + let namespace_prefix = subscribe_ns.namespace_prefix.clone(); + + // Register with subscriber registry to receive PUBLISH and PUBLISH_NAMESPACE notifications + // Uses session_id so we can exclude PUBLISH messages from the same session (self-exclusion) + let (_subscription_guard, mut publish_rx, mut publish_ns_rx) = + if let Some(ref registry) = self.subscriber_registry { + let (id, rx, rx_ns) = registry.register(namespace_prefix.clone(), self.session_id); + ( + Some(crate::SubscriptionGuard::new(registry.clone(), id)), + Some(rx), + Some(rx_ns), + ) + } else { + (None, None, None) + }; + + // Accept the subscription (even if no current matches - publisher may arrive later) + subscribe_ns.ok()?; + + log::info!( + "accepted SUBSCRIBE_NAMESPACE for prefix {:?}", + namespace_prefix + ); + + // Send PUBLISH for existing tracks in matching namespaces + // This triggers the client's onMatch callback for track discovery + // Note: We skip PUBLISH_NAMESPACE and send PUBLISH directly - client expects PUBLISH for tracks + let matching_tracks = self.locals.matching_tracks(&namespace_prefix); + log::info!( + "found {} existing tracks matching prefix {:?}", + matching_tracks.len(), + namespace_prefix + ); + + for (ns, track_name, track_info) in matching_tracks { + let track_extensions = track_info.track_extensions().unwrap_or_default(); + log::info!( + "sending PUBLISH for existing track {}/{} (matched prefix {:?}, extensions={:?})", + ns, + track_name, + namespace_prefix, + track_extensions + ); + + let track_reader = track_info.get_reader(); + let mut publisher = self.publisher.clone(); + + tokio::spawn(async move { + match publisher.publish_with_extensions(track_reader.clone(), track_extensions).await { + Ok(published) => { + log::info!( + "sent PUBLISH for existing track {}/{}, waiting for PUBLISH_OK", + ns, + track_name + ); + // serve() waits for PUBLISH_OK before streaming + match published.serve(track_reader).await { + Ok(()) => { + log::info!("existing track {}/{} serving completed", ns, track_name); + } + Err(e) => { + log::warn!("existing track {}/{} serving ended: {}", ns, track_name, e); + } + } + } + Err(e) => { + log::warn!("failed to send PUBLISH for existing track {}/{}: {}", ns, track_name, e); + } + } + }); + } + + // If we have a publish receiver, listen for new PUBLISH and PUBLISH_NAMESPACE notifications + if publish_rx.is_some() || publish_ns_rx.is_some() { + loop { + tokio::select! { + // Wait for the subscription to close + result = subscribe_ns.closed() => { + result?; + break; + } + // Wait for PUBLISH notifications -> forward PUBLISH to subscriber + // Subscriber sends PUBLISH_OK, then relay starts streaming data + notification = async { + if let Some(ref mut rx) = publish_rx { + rx.recv().await + } else { + std::future::pending().await + } + } => { + match notification { + Ok(publish_notif) => { + log::info!( + "received PUBLISH notification for {}/{} on subscription prefix {:?}", + publish_notif.namespace, + publish_notif.track_name, + namespace_prefix + ); + + // Get the TrackReader for this track so we can stream data + if let Some(track_info) = self.locals.get_track_info( + &publish_notif.namespace, + &publish_notif.track_name, + ) { + let track_reader = track_info.get_reader(); + let track_extensions = track_info.track_extensions().unwrap_or_default(); + + // Send PUBLISH and wait for PUBLISH_OK before streaming + let mut publisher = self.publisher.clone(); + let ns = publish_notif.namespace.clone(); + let name = publish_notif.track_name.clone(); + log::info!( + "forwarding PUBLISH for {}/{} with extensions {:?}", + ns, name, track_extensions + ); + tokio::spawn(async move { + match publisher.publish_with_extensions(track_reader.clone(), track_extensions).await { + Ok(published) => { + log::info!( + "sent PUBLISH for {}/{}, waiting for PUBLISH_OK", + ns, name + ); + // serve() waits for PUBLISH_OK before streaming + match published.serve(track_reader).await { + Ok(()) => { + log::info!("track {}/{} serving completed", ns, name); + } + Err(e) => { + log::warn!( + "track {}/{} serving ended: {}", + ns, name, e + ); + } + } + } + Err(e) => { + log::warn!( + "failed to send PUBLISH for {}/{}: {}", + ns, name, e + ); + } + } + }); + } else { + log::warn!( + "no track info found for {}/{}, cannot forward PUBLISH", + publish_notif.namespace, + publish_notif.track_name + ); + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + log::warn!("subscription lagged by {} messages", n); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + log::debug!("publish notification channel closed"); + break; + } + } + } + // PUBLISH_NAMESPACE notifications - we don't forward these as NAMESPACE messages + // Client expects PUBLISH for individual tracks, not namespace announcements + notification = async { + if let Some(ref mut rx) = publish_ns_rx { + rx.recv().await + } else { + std::future::pending().await + } + } => { + match notification { + Ok(ns_notif) => { + log::debug!( + "ignoring PUBLISH_NAMESPACE notification for {:?} (client expects PUBLISH for tracks)", + ns_notif.namespace + ); + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + log::warn!("namespace subscription lagged by {} messages", n); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + log::debug!("publish_namespace notification channel closed"); + break; + } + } + } + } + } + } else { + // No registry, just wait for close + subscribe_ns.closed().await?; + } + + Ok(()) + } + async fn serve_track_status( self, mut track_status_requested: TrackStatusRequested, diff --git a/moq-relay-ietf/src/relay.rs b/moq-relay-ietf/src/relay.rs index 0daf6edb..8657e1ab 100644 --- a/moq-relay-ietf/src/relay.rs +++ b/moq-relay-ietf/src/relay.rs @@ -8,6 +8,7 @@ use url::Url; use crate::{ Consumer, Coordinator, Locals, Producer, Remotes, RemotesConsumer, RemotesProducer, Session, + SubscriberRegistry, }; // A type alias for boxed future @@ -58,6 +59,7 @@ pub struct Relay { locals: Locals, remotes: Option<(RemotesProducer, RemotesConsumer)>, coordinator: Arc, + subscriber_registry: SubscriberRegistry, } impl Relay { @@ -107,6 +109,9 @@ impl Relay { } .produce(); + // Create subscriber registry for SUBSCRIBE_NAMESPACE tracking + let subscriber_registry = SubscriberRegistry::new(); + Ok(Self { quic_endpoints: endpoints, announce_url: config.announce, @@ -114,6 +119,7 @@ impl Relay { locals, remotes: Some(remotes), coordinator: config.coordinator, + subscriber_registry, }) } @@ -219,6 +225,7 @@ impl Relay { let remotes = remotes.clone(); let forward = forward_producer.clone(); let coordinator = self.coordinator.clone(); + let subscriber_registry = self.subscriber_registry.clone(); // Spawn a new task to handle the connection tasks.push(async move { @@ -232,11 +239,36 @@ impl Relay { }; // Create our MoQ relay session + // Use connection_id hash as session_id for self-exclusion in pub/sub + use std::hash::{Hash, Hasher}; + let session_id = { + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + connection_id.hash(&mut hasher); + hasher.finish() + }; + let moq_session = session; let session = Session { session: moq_session, - producer: publisher.map(|publisher| Producer::new(publisher, locals.clone(), remotes)), - consumer: subscriber.map(|subscriber| Consumer::new(subscriber, locals, coordinator, forward)), + producer: publisher.map(|publisher| { + Producer::with_registry( + publisher, + locals.clone(), + remotes, + subscriber_registry.clone(), + session_id, + ) + }), + consumer: subscriber.map(|subscriber| { + Consumer::with_registry( + subscriber, + locals, + coordinator, + forward, + subscriber_registry, + session_id, + ) + }), }; if let Err(err) = session.run().await { diff --git a/moq-relay-ietf/src/subscriber_registry.rs b/moq-relay-ietf/src/subscriber_registry.rs new file mode 100644 index 00000000..52c8aec5 --- /dev/null +++ b/moq-relay-ietf/src/subscriber_registry.rs @@ -0,0 +1,325 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +use moq_transport::coding::TrackNamespace; +use tokio::sync::broadcast; + +/// Information about an active SUBSCRIBE_NAMESPACE subscription +#[derive(Clone)] +pub struct NamespaceSubscription { + /// The namespace prefix this subscription is for + pub prefix: TrackNamespace, + /// Session ID of the subscriber (for self-exclusion) + pub session_id: u64, + /// Channel to send PUBLISH notifications to this subscriber + pub publish_tx: broadcast::Sender, + /// Channel to send PUBLISH_NAMESPACE notifications to this subscriber + pub publish_ns_tx: broadcast::Sender, +} + +/// Notification sent when a PUBLISH arrives that matches a subscription +#[derive(Clone, Debug)] +pub struct PublishNotification { + pub namespace: TrackNamespace, + pub track_name: String, + pub track_alias: u64, +} + +/// Notification sent when a PUBLISH_NAMESPACE arrives that matches a subscription +#[derive(Clone, Debug)] +pub struct PublishNamespaceNotification { + pub namespace: TrackNamespace, +} + +/// Registry for tracking active SUBSCRIBE_NAMESPACE subscriptions +/// +/// When a subscriber sends SUBSCRIBE_NAMESPACE, they register here. +/// When a publisher sends PUBLISH, we find matching subscriptions and notify. +#[derive(Clone)] +pub struct SubscriberRegistry { + inner: Arc>, +} + +struct SubscriberRegistryInner { + /// Map from subscription ID to subscription info + subscriptions: HashMap, + /// Next subscription ID + next_id: u64, +} + +impl SubscriberRegistry { + pub fn new() -> Self { + Self { + inner: Arc::new(Mutex::new(SubscriberRegistryInner { + subscriptions: HashMap::new(), + next_id: 0, + })), + } + } + + /// Register a SUBSCRIBE_NAMESPACE subscription + /// Returns (subscription_id, receiver for PUBLISH notifications, receiver for PUBLISH_NAMESPACE notifications) + pub fn register( + &self, + prefix: TrackNamespace, + session_id: u64, + ) -> ( + u64, + broadcast::Receiver, + broadcast::Receiver, + ) { + let mut inner = self.inner.lock().unwrap(); + + let id = inner.next_id; + inner.next_id += 1; + + // Create broadcast channels for PUBLISH and PUBLISH_NAMESPACE notifications + let (publish_tx, publish_rx) = broadcast::channel(64); + let (publish_ns_tx, publish_ns_rx) = broadcast::channel(64); + + let subscription = NamespaceSubscription { + prefix, + session_id, + publish_tx, + publish_ns_tx, + }; + + inner.subscriptions.insert(id, subscription); + + log::debug!( + "registered namespace subscription id={} session_id={}", + id, + session_id + ); + + (id, publish_rx, publish_ns_rx) + } + + /// Unregister a subscription + pub fn unregister(&self, id: u64) { + let mut inner = self.inner.lock().unwrap(); + if inner.subscriptions.remove(&id).is_some() { + log::debug!("unregistered namespace subscription id={}", id); + } + } + + /// Find all subscriptions that match a given namespace and notify them of a PUBLISH + /// Excludes the session that originated the PUBLISH (self-exclusion) + /// Returns the number of matching subscriptions notified + pub fn notify_publish( + &self, + namespace: &TrackNamespace, + track_name: &str, + track_alias: u64, + origin_session_id: u64, + ) -> usize { + let inner = self.inner.lock().unwrap(); + + let notification = PublishNotification { + namespace: namespace.clone(), + track_name: track_name.to_string(), + track_alias, + }; + + let mut notified = 0; + + for (id, sub) in inner.subscriptions.iter() { + // Skip if this subscription belongs to the same session that sent the PUBLISH + if sub.session_id == origin_session_id { + log::debug!( + "skipping subscription id={} (same session {})", + id, + origin_session_id + ); + continue; + } + + // Check if the namespace matches the subscription prefix + // The subscription prefix should be a prefix of the namespace + if Self::prefix_matches(&sub.prefix, namespace) { + if let Err(e) = sub.publish_tx.send(notification.clone()) { + log::warn!("failed to notify subscription id={}: {}", id, e); + } else { + log::debug!( + "notified subscription id={} of PUBLISH {}/{}", + id, + namespace, + track_name + ); + notified += 1; + } + } + } + + notified + } + + /// Find all subscriptions that match a given namespace and notify them of a PUBLISH_NAMESPACE + /// Excludes the session that originated the PUBLISH_NAMESPACE (self-exclusion) + /// Returns the number of matching subscriptions notified + pub fn notify_publish_namespace(&self, namespace: &TrackNamespace, origin_session_id: u64) -> usize { + let inner = self.inner.lock().unwrap(); + + let notification = PublishNamespaceNotification { + namespace: namespace.clone(), + }; + + let mut notified = 0; + + for (id, sub) in inner.subscriptions.iter() { + // Skip if this subscription belongs to the same session that sent the PUBLISH_NAMESPACE + if sub.session_id == origin_session_id { + log::debug!( + "skipping subscription id={} for PUBLISH_NAMESPACE (same session {})", + id, + origin_session_id + ); + continue; + } + + // Check if the namespace matches the subscription prefix + if Self::prefix_matches(&sub.prefix, namespace) { + if let Err(e) = sub.publish_ns_tx.send(notification.clone()) { + log::warn!( + "failed to notify subscription id={} of PUBLISH_NAMESPACE: {}", + id, + e + ); + } else { + log::debug!( + "notified subscription id={} of PUBLISH_NAMESPACE {:?}", + id, + namespace + ); + notified += 1; + } + } + } + + notified + } + + /// Check if prefix is a prefix of namespace + fn prefix_matches(prefix: &TrackNamespace, namespace: &TrackNamespace) -> bool { + if prefix.fields.len() > namespace.fields.len() { + return false; + } + + prefix + .fields + .iter() + .zip(namespace.fields.iter()) + .all(|(a, b)| a == b) + } + + /// Get all subscriptions matching a prefix (for debugging) + pub fn matching_subscriptions(&self, namespace: &TrackNamespace) -> Vec { + let inner = self.inner.lock().unwrap(); + + inner + .subscriptions + .iter() + .filter(|(_, sub)| Self::prefix_matches(&sub.prefix, namespace)) + .map(|(id, _)| *id) + .collect() + } +} + +impl Default for SubscriberRegistry { + fn default() -> Self { + Self::new() + } +} + +/// RAII guard that unregisters on drop +pub struct SubscriptionGuard { + registry: SubscriberRegistry, + id: u64, +} + +impl SubscriptionGuard { + pub fn new(registry: SubscriberRegistry, id: u64) -> Self { + Self { registry, id } + } + + pub fn id(&self) -> u64 { + self.id + } +} + +impl Drop for SubscriptionGuard { + fn drop(&mut self) { + self.registry.unregister(self.id); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn ns(path: &str) -> TrackNamespace { + TrackNamespace::from_utf8_path(path) + } + + #[test] + fn test_prefix_matching() { + assert!(SubscriberRegistry::prefix_matches(&ns("live"), &ns("live/stream1"))); + assert!(SubscriberRegistry::prefix_matches(&ns("live"), &ns("live"))); + // An empty prefix (zero fields) should match everything + let empty = TrackNamespace::new(); + assert!(SubscriberRegistry::prefix_matches(&empty, &ns("live/stream1"))); + assert!(!SubscriberRegistry::prefix_matches(&ns("live/stream1"), &ns("live"))); + assert!(!SubscriberRegistry::prefix_matches(&ns("other"), &ns("live/stream1"))); + } + + #[test] + fn test_register_unregister() { + let registry = SubscriberRegistry::new(); + + let (id1, _rx1, _rx1_ns) = registry.register(ns("live"), 100); + let (id2, _rx2, _rx2_ns) = registry.register(ns("live/room1"), 101); + + assert_eq!(registry.matching_subscriptions(&ns("live/room1/track")).len(), 2); + + registry.unregister(id1); + + assert_eq!(registry.matching_subscriptions(&ns("live/room1/track")).len(), 1); + + registry.unregister(id2); + + assert_eq!(registry.matching_subscriptions(&ns("live/room1/track")).len(), 0); + } + + #[tokio::test] + async fn test_notify_publish() { + let registry = SubscriberRegistry::new(); + + // Register with session_id=100 + let (id, mut rx, _rx_ns) = registry.register(ns("live"), 100); + + // Notify from session 200 (different) - should be delivered + let notified = registry.notify_publish(&ns("live/stream1"), "video", 42, 200); + assert_eq!(notified, 1); + + let notification = rx.recv().await.unwrap(); + assert_eq!(notification.track_name, "video"); + assert_eq!(notification.track_alias, 42); + + registry.unregister(id); + } + + #[tokio::test] + async fn test_self_exclusion() { + let registry = SubscriberRegistry::new(); + + // Register with session_id=100 + let (_id, mut rx, _rx_ns) = registry.register(ns("live"), 100); + + // Notify from the same session (100) - should NOT be delivered + let notified = registry.notify_publish(&ns("live/stream1"), "video", 42, 100); + assert_eq!(notified, 0); + + // Verify nothing was received (use try_recv to avoid blocking) + assert!(rx.try_recv().is_err()); + } +} diff --git a/moq-sub/src/media.rs b/moq-sub/src/media.rs index e3fd43c5..5503d9eb 100644 --- a/moq-sub/src/media.rs +++ b/moq-sub/src/media.rs @@ -183,16 +183,54 @@ impl Media { async fn recv_group(mut group: SubgroupReader, out: Arc>) -> anyhow::Result<()> { trace!("group={} start", group.group_id); + + // Pair moof+mdat into a single atomic write to prevent concurrent + // audio/video tasks from interleaving between them on stdout. + let mut pending_moof: Option> = None; + while let Some(object) = group.next().await? { trace!( "group={} fragment={} start", group.group_id, object.object_id ); - let out = out.clone(); let buf = Self::recv_object(object).await?; - out.lock().await.write_all(&buf).await?; + let is_moof = buf.len() >= 8 && &buf[4..8] == b"moof"; + let is_mdat = buf.len() >= 8 && &buf[4..8] == b"mdat"; + + if is_moof { + if let Some(orphan) = pending_moof.take() { + warn!( + "group={}: flushing orphaned moof ({} bytes) without mdat", + group.group_id, + orphan.len() + ); + out.lock().await.write_all(&orphan).await?; + } + pending_moof = Some(buf); + } else if is_mdat { + if let Some(mut moof) = pending_moof.take() { + moof.extend_from_slice(&buf); + out.lock().await.write_all(&moof).await?; + } else { + warn!( + "group={}: mdat without preceding moof ({} bytes)", + group.group_id, + buf.len() + ); + out.lock().await.write_all(&buf).await?; + } + } else { + if let Some(orphan) = pending_moof.take() { + out.lock().await.write_all(&orphan).await?; + } + out.lock().await.write_all(&buf).await?; + } + } + + if let Some(orphan) = pending_moof.take() { + out.lock().await.write_all(&orphan).await?; } Ok(()) diff --git a/moq-test-client/Cargo.toml b/moq-test-client/Cargo.toml index f4a4db17..37840a02 100644 --- a/moq-test-client/Cargo.toml +++ b/moq-test-client/Cargo.toml @@ -18,7 +18,7 @@ path = "src/main.rs" [dependencies] moq-transport = { path = "../moq-transport", version = "0.12" } moq-native-ietf = { path = "../moq-native-ietf", version = "0.7" } -web-transport = "0.3" +web-transport = { workspace = true } url = "2" diff --git a/moq-test-client/src/main.rs b/moq-test-client/src/main.rs index 9438935b..7b925a66 100644 --- a/moq-test-client/src/main.rs +++ b/moq-test-client/src/main.rs @@ -140,9 +140,9 @@ async fn run_test(args: &Args, test_case: TestCase) -> TestResult { let result = match test_case { TestCase::SetupOnly => scenarios::test_setup_only(args).await, - TestCase::AnnounceOnly => scenarios::test_announce_only(args).await, + TestCase::AnnounceOnly => scenarios::test_publish_namespace_only(args).await, TestCase::SubscribeError => scenarios::test_subscribe_error(args).await, - TestCase::AnnounceSubscribe => scenarios::test_announce_subscribe(args).await, + TestCase::AnnounceSubscribe => scenarios::test_publish_namespace_subscribe(args).await, TestCase::SubscribeBeforeAnnounce => scenarios::test_subscribe_before_announce(args).await, TestCase::PublishNamespaceDone => scenarios::test_publish_namespace_done(args).await, }; diff --git a/moq-test-client/src/scenarios.rs b/moq-test-client/src/scenarios.rs index ce6a923c..6752c836 100644 --- a/moq-test-client/src/scenarios.rs +++ b/moq-test-client/src/scenarios.rs @@ -10,7 +10,11 @@ use anyhow::{Context, Result}; use tokio::time::{timeout, Duration}; use moq_native_ietf::quic; -use moq_transport::{coding::TrackNamespace, serve::Tracks, session::Session}; +use moq_transport::{ + coding::TrackNamespace, + serve::Tracks, + session::{Publisher, Session}, +}; use crate::Args; @@ -20,7 +24,7 @@ const TEST_TIMEOUT: Duration = Duration::from_secs(10); /// Namespace used for test operations const TEST_NAMESPACE: &str = "moq-test/interop"; -/// Track name used for test operations +/// Track name used for test operations const TEST_TRACK: &str = "test-track"; /// Helper to connect to a relay and establish a session @@ -72,10 +76,10 @@ pub async fn test_setup_only(args: &Args) -> Result { .context("test timed out")? } -/// T0.2: Announce Only +/// T0.2: Publish namespace Only /// -/// Connect to relay, announce a namespace, receive PUBLISH_NAMESPACE_OK, close. -pub async fn test_announce_only(args: &Args) -> Result { +/// Connect to relay, publish a namespace, receive PUBLISH_NAMESPACE_OK, close. +pub async fn test_publish_namespace_only(args: &Args) -> Result { timeout(TEST_TIMEOUT, async { let (session, cid) = connect(args).await.context("failed to connect to relay")?; let mut cids = TestConnectionIds::default(); @@ -86,32 +90,31 @@ pub async fn test_announce_only(args: &Args) -> Result { .context("SETUP exchange failed")?; let namespace = TrackNamespace::from_utf8_path(TEST_NAMESPACE); - let (_, _, reader) = Tracks::new(namespace.clone()).produce(); - log::info!("Announcing namespace: {}", TEST_NAMESPACE); + log::info!("Publishing namespace: {}", TEST_NAMESPACE); - // Run announce with a timeout - we want to verify we get PUBLISH_NAMESPACE_OK. - // NOTE: The announce() method blocks waiting for subscriptions after getting OK. + // Run publish namespace with a timeout - we want to verify we get PUBLISH_NAMESPACE_OK. + // NOTE: The publish_namespace() method sends PUBLISH_NAMESPACE and wait for OK or ERROR. // If we get PUBLISH_NAMESPACE_ERROR instead of OK, the method returns Err immediately. - // So timing out here means: either (a) got OK and waiting for subs, or (b) relay never responded. - // We accept this limitation since (b) would indicate a broken relay anyway. - // TODO: For stricter verification, use lower-level Announce::ok() method directly. - let announce_result = tokio::select! { - res = publisher.announce(reader) => res, + // So timing out here means relay never responded and connection may be broken. + let publish_ns = publisher.publish_namespace(namespace).await?; + + let publish_ns_result = tokio::select! { + res = publish_ns.ok() => res, res = session.run() => { res.context("session error")?; anyhow::bail!("session ended before announce completed"); } _ = tokio::time::sleep(Duration::from_secs(2)) => { // If we got an error from the relay, announce() would have returned already. - // Timing out means we're past the OK and now waiting for subscriptions. - log::info!("Announce succeeded (no error received, waiting for subscriptions timed out)"); - return Ok(cids); + // Timing out means the relay never responded and connection may be broken. + log::info!("Publishing namespace failed (relay did not reply)"); + return Err(anyhow::anyhow!("publish namespace timed out")); } }; - // If we get here, announce completed (which means it errored or namespace was cancelled) - announce_result.context("announce failed")?; + // If we get here, publish namespace completed (which means it errored or namespace was cancelled) + publish_ns_result.context("publish namespace failed")?; Ok(cids) }) @@ -190,11 +193,11 @@ pub async fn test_subscribe_error(args: &Args) -> Result { .context("test timed out")? } -/// T0.4: Announce + Subscribe +/// T0.4: Publish Namespace + Subscribe /// -/// Two clients: publisher announces a namespace, subscriber subscribes to a track. +/// Two clients: publisher publishes a namespace, subscriber subscribes to a track. /// Verifies the relay correctly routes the subscription to the publisher. -pub async fn test_announce_subscribe(args: &Args) -> Result { +pub async fn test_publish_namespace_subscribe(args: &Args) -> Result { timeout(TEST_TIMEOUT, async { let mut cids = TestConnectionIds::default(); @@ -222,7 +225,12 @@ pub async fn test_announce_subscribe(args: &Args) -> Result { // Create the track that subscriber will request let _track_writer = pub_writer.create(TEST_TRACK); - log::info!("Publisher announcing namespace: {}", TEST_NAMESPACE); + log::info!("Publisher publishing namespace: {}", TEST_NAMESPACE); + + let publish_ns = publisher + .publish_namespace(namespace.clone()) + .await + .context("publish_namespace call failed")?; // Subscriber: set up tracks and subscribe let (mut sub_writer, _, _sub_reader) = Tracks::new(namespace.clone()).produce(); @@ -230,37 +238,51 @@ pub async fn test_announce_subscribe(args: &Args) -> Result { .create(TEST_TRACK) .ok_or_else(|| anyhow::anyhow!("failed to create subscriber track"))?; - log::info!( - "Subscriber subscribing to track: {}/{}", - TEST_NAMESPACE, - TEST_TRACK - ); - - // Run everything concurrently. We expect the subscriber to get a response - // (either SUBSCRIBE_OK or error) within the timeout. + // Run everything concurrently. Session::run() consumes self, so + // publish_namespace→subscribe must be sequenced inside a single async + // block running alongside both sessions. + let mut pub_subscriber_handler = publisher.clone(); tokio::select! { - // Publisher announces and waits for subscriptions - res = publisher.announce(pub_reader) => { - res.context("publisher announce failed")?; - log::info!("Publisher announce completed"); - } - // Subscriber subscribes - this is the main thing we're testing - res = subscriber.subscribe(sub_track) => { - match res { + // Publisher publishes namespace, then subscriber subscribes + res = async { + publish_ns.ok().await.context("publish namespace failed")?; + log::info!("Publisher got PUBLISH_NAMESPACE_OK"); + + log::info!("Subscribing to track: {}/{}", TEST_NAMESPACE, TEST_TRACK); + // Subscriber subscribes - this is the main thing we're testing + match subscriber.subscribe(sub_track).await { Ok(()) => log::info!("Subscriber got SUBSCRIBE_OK - relay routed subscription correctly"), Err(e) => log::info!("Subscriber got error: {} - subscription was processed", e), } + Ok::<_, anyhow::Error>(()) + } => { + res?; + } + // Serve incoming subscriptions forwarded by the relay to the publisher + res = async { + while let Some(subscribed) = pub_subscriber_handler.subscribed().await { + let info = subscribed.info.clone(); + log::info!("Publisher serving subscribe: {:?}", info); + if let Err(err) = Publisher::serve_subscribe(subscribed, pub_reader.clone()).await { + log::warn!("Failed serving subscribe: {:?}, error: {}", info, err); + } + } + Ok::<_, anyhow::Error>(()) + } => { + res?; } // Run publisher session res = pub_session.run() => { res.context("publisher session error")?; + anyhow::bail!("publisher session ended unexpectedly"); } // Run subscriber session res = sub_session.run() => { res.context("subscriber session error")?; + anyhow::bail!("subscriber session ended unexpectedly"); } // Timeout: give the relay time to route the subscription - _ = tokio::time::sleep(Duration::from_secs(3)) => { + _ = tokio::time::sleep(Duration::from_secs(5)) => { // If we hit this timeout, the subscription may still be pending. // This isn't necessarily a failure - some relays may hold subscriptions // until the track has data. Log for visibility. @@ -289,27 +311,30 @@ pub async fn test_publish_namespace_done(args: &Args) -> Result res, + let publish_ns = publisher.publish_namespace(namespace).await?; + + let publish_ns_result = tokio::select! { + res = publish_ns.ok() => res, res = session.run() => { res.context("session error")?; anyhow::bail!("session ended before announce completed"); } _ = tokio::time::sleep(Duration::from_secs(2)) => { - // No error received - announce is active and waiting for subscriptions - log::info!("Announce active, now sending PUBLISH_NAMESPACE_DONE"); - // Dropping out of this block will drop the announce, which sends PUBLISH_NAMESPACE_DONE - Ok(()) + // If we got an error from the relay, announce() would have returned already. + // Timing out means the relay never responded and connection may be broken. + log::info!("Publishing namespace failed (relay did not reply)"); + return Err(anyhow::anyhow!("publish namespace timed out")); } }; - result.context("announce failed")?; + publish_ns_result.context("publish namespace failed")?; + + drop(publish_ns); // Small delay to ensure PUBLISH_NAMESPACE_DONE is sent before we close tokio::time::sleep(Duration::from_millis(100)).await; @@ -374,17 +399,16 @@ pub async fn test_subscribe_before_announce(args: &Args) -> Result { + res = publish_ns.ok() => { res.context("publisher announce failed")?; } res = pub_session.run() => { diff --git a/moq-transport/src/coding/kvp.rs b/moq-transport/src/coding/kvp.rs index 2ed9caa9..065f5d39 100644 --- a/moq-transport/src/coding/kvp.rs +++ b/moq-transport/src/coding/kvp.rs @@ -48,13 +48,46 @@ impl KeyValuePair { value: Value::BytesValue(value), } } -} -impl Decode for KeyValuePair { - fn decode(r: &mut R) -> Result { - let key = u64::decode(r)?; + /// Validate that the key parity matches the value type. + /// Even keys => IntValue, Odd keys => BytesValue. + fn validate_key_parity(&self) -> Result<(), EncodeError> { + match &self.value { + Value::IntValue(_) => { + if !self.key.is_multiple_of(2) { + return Err(EncodeError::InvalidValue); + } + } + Value::BytesValue(_) => { + if self.key.is_multiple_of(2) { + return Err(EncodeError::InvalidValue); + } + } + } + Ok(()) + } - if key % 2 == 0 { + /// Encode only the value portion of this KVP (not the key/delta). + /// The caller is responsible for encoding the key or delta type. + pub(crate) fn encode_value(&self, w: &mut W) -> Result<(), EncodeError> { + self.validate_key_parity()?; + match &self.value { + Value::IntValue(v) => { + (*v).encode(w)?; + } + Value::BytesValue(v) => { + v.len().encode(w)?; + Self::encode_remaining(w, v.len())?; + w.put_slice(v); + } + } + Ok(()) + } + + /// Decode only the value portion of a KVP given the absolute key. + /// The caller has already decoded the key/delta and resolved the absolute key. + pub(crate) fn decode_value(key: u64, r: &mut R) -> Result { + if key.is_multiple_of(2) { // VarInt variant let value = u64::decode(r)?; log::trace!("[KVP] Decoded even key={}, value={}", key, value); @@ -81,30 +114,22 @@ impl Decode for KeyValuePair { } } +/// Legacy Decode for KeyValuePair — reads absolute key from wire. +/// Used only by ExtensionHeaders which reads KVPs from a bounded byte slice. +impl Decode for KeyValuePair { + fn decode(r: &mut R) -> Result { + let key = u64::decode(r)?; + Self::decode_value(key, r) + } +} + +/// Legacy Encode for KeyValuePair — writes absolute key to wire. +/// Used only by ExtensionHeaders which writes KVPs into a temporary buffer. impl Encode for KeyValuePair { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - match &self.value { - Value::IntValue(v) => { - // key must be even for IntValue - if !self.key.is_multiple_of(2) { - return Err(EncodeError::InvalidValue); - } - self.key.encode(w)?; - (*v).encode(w)?; - Ok(()) - } - Value::BytesValue(v) => { - // key must be odd for BytesValue - if self.key.is_multiple_of(2) { - return Err(EncodeError::InvalidValue); - } - self.key.encode(w)?; - v.len().encode(w)?; - Self::encode_remaining(w, v.len())?; - w.put_slice(v); - Ok(()) - } - } + self.validate_key_parity()?; + self.key.encode(w)?; + self.encode_value(w) } } @@ -116,7 +141,10 @@ impl fmt::Debug for KeyValuePair { /// A collection of KeyValuePair entries, where the number of key-value-pairs are encoded/decoded first. /// This structure is appropriate for Control message parameters. -/// Since duplicate parameters are allowed for unknown parameters, we don't do duplicate checking here. +/// +/// Per draft-16 Section 1.4.2, Key-Value-Pairs use delta-encoded Type fields: +/// each Type is encoded as a delta from the previous Type (or from 0 for the first). +/// Entries are sorted by key (Type) in ascending order for encoding. #[derive(Default, Clone, Eq, PartialEq)] pub struct KeyValuePairs(pub Vec); @@ -150,16 +178,49 @@ impl KeyValuePairs { pub fn get(&self, key: u64) -> Option<&KeyValuePair> { self.0.iter().find(|k| k.key == key) } + + /// Get an integer value by key, returning None if not found or if the value is not an integer + pub fn get_intvalue(&self, key: u64) -> Option { + self.get(key).and_then(|kvp| match &kvp.value { + Value::IntValue(v) => Some(*v), + Value::BytesValue(_) => None, + }) + } + + /// Get a bytes value by key, returning None if not found or if the value is not bytes + pub fn get_bytesvalue(&self, key: u64) -> Option<&Vec> { + self.get(key).and_then(|kvp| match &kvp.value { + Value::IntValue(_) => None, + Value::BytesValue(v) => Some(v), + }) + } } impl Decode for KeyValuePairs { - fn decode(mut r: &mut R) -> Result { + /// Decode Key-Value-Pairs with delta-encoded Type fields (draft-16 Section 1.4.2). + fn decode(r: &mut R) -> Result { let mut kvps = Vec::new(); let count = u64::decode(r)?; + let mut prev_key: u64 = 0; + for _ in 0..count { - let kvp = KeyValuePair::decode(&mut r)?; + // Read delta type + let delta = u64::decode(r)?; + + // Reconstruct absolute key: prev_key + delta + let key = prev_key.checked_add(delta).ok_or_else(|| { + log::error!( + "[KVP] Delta type overflow: prev_key={}, delta={}", + prev_key, + delta + ); + DecodeError::BoundsExceeded(crate::coding::BoundsExceeded) + })?; + + let kvp = KeyValuePair::decode_value(key, r)?; kvps.push(kvp); + prev_key = key; } Ok(KeyValuePairs(kvps)) @@ -167,11 +228,32 @@ impl Decode for KeyValuePairs { } impl Encode for KeyValuePairs { + /// Encode Key-Value-Pairs with delta-encoded Type fields (draft-16 Section 1.4.2). + /// Entries are sorted by key in ascending order before encoding. fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.0.len().encode(w)?; - for kvp in &self.0 { - kvp.encode(w)?; + // Sort by key for delta encoding (Types must be in ascending order) + let mut sorted: Vec<&KeyValuePair> = self.0.iter().collect(); + sorted.sort_by_key(|kvp| kvp.key); + + let mut prev_key: u64 = 0; + for kvp in sorted { + // Compute and encode the delta + let delta = kvp.key.checked_sub(prev_key).ok_or_else(|| { + log::error!( + "[KVP] Keys not sortable: prev_key={}, current_key={}", + prev_key, + kvp.key + ); + EncodeError::InvalidValue + })?; + delta.encode(w)?; + + // Encode the value (without the key) + kvp.encode_value(w)?; + + prev_key = kvp.key; } Ok(()) @@ -243,9 +325,10 @@ mod tests { } #[test] - fn encode_decode_keyvaluepairs() { + fn encode_decode_keyvaluepairs_single() { let mut buf = BytesMut::new(); + // Single entry: key=1 (odd, bytes). Delta from 0 = 1. let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); kvps.encode(&mut buf).unwrap(); @@ -253,21 +336,79 @@ mod tests { buf.to_vec(), vec![ 0x01, // 1 KeyValuePair - 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, // Key=1, Value=[1,2,3,4,5] + // Delta=1 (from 0), then length=5, then data + 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, ] ); let decoded = KeyValuePairs::decode(&mut buf).unwrap(); assert_eq!(decoded, kvps); + } + #[test] + fn encode_decode_keyvaluepairs_multiple() { + let mut buf = BytesMut::new(); + + // Multiple entries inserted out of order — encoding should sort by key. + // Keys: 0 (even, int), 1 (odd, bytes), 100 (even, int) let mut kvps = KeyValuePairs::new(); kvps.set_intvalue(0, 0); kvps.set_intvalue(100, 100); kvps.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); kvps.encode(&mut buf).unwrap(); - let buf_vec = buf.to_vec(); - // Validate the encoded length and the KeyValuePair count - assert_eq!(14, buf_vec.len()); // 14 bytes total - assert_eq!(3, buf_vec[0]); // 3 KeyValuePairs + + #[rustfmt::skip] + let expected = vec![ + 0x03, // 3 KeyValuePairs + // Entry 1: key=0 (delta=0 from 0), even, int value=0 + 0x00, 0x00, + // Entry 2: key=1 (delta=1 from 0), odd, bytes len=5 + 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, + // Entry 3: key=100 (delta=99 from 1), even, int value=100 + 0x40, 0x63, 0x40, 0x64, + ]; + assert_eq!(buf.to_vec(), expected); + + // Decode and verify — decoded entries will be in sorted order + let decoded = KeyValuePairs::decode(&mut buf).unwrap(); + // Build expected sorted kvps for comparison + let mut expected_kvps = KeyValuePairs::new(); + expected_kvps.set_intvalue(0, 0); + expected_kvps.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); + expected_kvps.set_intvalue(100, 100); + assert_eq!(decoded, expected_kvps); + } + + #[test] + fn encode_decode_keyvaluepairs_roundtrip_sorted() { + let mut buf = BytesMut::new(); + + // Insert in sorted order — should roundtrip exactly + let mut kvps = KeyValuePairs::new(); + kvps.set_intvalue(2, 42); + kvps.set_intvalue(4, 100); + kvps.encode(&mut buf).unwrap(); + + #[rustfmt::skip] + let expected = vec![ + 0x02, // 2 KeyValuePairs + // Entry 1: key=2 (delta=2), int value=42 + 0x02, 0x2a, + // Entry 2: key=4 (delta=2 from 2), int value=100 + 0x02, 0x40, 0x64, + ]; + assert_eq!(buf.to_vec(), expected); + + let decoded = KeyValuePairs::decode(&mut buf).unwrap(); + assert_eq!(decoded, kvps); + } + + #[test] + fn encode_decode_keyvaluepairs_empty() { + let mut buf = BytesMut::new(); + + let kvps = KeyValuePairs::new(); + kvps.encode(&mut buf).unwrap(); + assert_eq!(buf.to_vec(), vec![0x00]); // count=0 let decoded = KeyValuePairs::decode(&mut buf).unwrap(); assert_eq!(decoded, kvps); } diff --git a/moq-transport/src/coding/mod.rs b/moq-transport/src/coding/mod.rs index 13e97be1..71cc4148 100644 --- a/moq-transport/src/coding/mod.rs +++ b/moq-transport/src/coding/mod.rs @@ -6,6 +6,7 @@ mod integer; mod kvp; mod location; mod string; +mod track_extensions; mod track_namespace; mod tuple; mod varint; @@ -16,6 +17,7 @@ pub use encode::*; pub use hex_dump::*; pub use kvp::*; pub use location::*; +pub use track_extensions::*; pub use track_namespace::*; pub use tuple::*; pub use varint::*; diff --git a/moq-transport/src/coding/track_extensions.rs b/moq-transport/src/coding/track_extensions.rs new file mode 100644 index 00000000..3da50c55 --- /dev/null +++ b/moq-transport/src/coding/track_extensions.rs @@ -0,0 +1,196 @@ +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePair, Value}; +use std::fmt; + +/// A collection of KeyValuePair entries for Track Extensions. +/// Per draft-16 Section 9.10, Track Extensions are encoded WITHOUT a count or length prefix. +/// They are simply a sequence of delta-encoded key-value pairs until end of message. +/// +/// This differs from: +/// - KeyValuePairs: has a count prefix +/// - ExtensionHeaders: has a byte-length prefix +#[derive(Default, Clone, Eq, PartialEq)] +pub struct TrackExtensions(pub Vec); + +impl TrackExtensions { + pub fn new() -> Self { + Self::default() + } + + /// Insert or replace a KeyValuePair with the same key. + pub fn set(&mut self, kvp: KeyValuePair) { + if let Some(existing) = self.0.iter_mut().find(|k| k.key == kvp.key) { + *existing = kvp; + } else { + self.0.push(kvp); + } + } + + pub fn set_intvalue(&mut self, key: u64, value: u64) { + self.set(KeyValuePair::new_int(key, value)); + } + + pub fn set_bytesvalue(&mut self, key: u64, value: Vec) { + self.set(KeyValuePair::new_bytes(key, value)); + } + + pub fn has(&self, key: u64) -> bool { + self.0.iter().any(|k| k.key == key) + } + + pub fn get(&self, key: u64) -> Option<&KeyValuePair> { + self.0.iter().find(|k| k.key == key) + } + + /// Get an integer value by key, returning None if not found or if the value is not an integer + pub fn get_intvalue(&self, key: u64) -> Option { + self.get(key).and_then(|kvp| match &kvp.value { + Value::IntValue(v) => Some(*v), + Value::BytesValue(_) => None, + }) + } + + /// Get a bytes value by key, returning None if not found or if the value is not bytes + pub fn get_bytesvalue(&self, key: u64) -> Option<&Vec> { + self.get(key).and_then(|kvp| match &kvp.value { + Value::IntValue(_) => None, + Value::BytesValue(v) => Some(v), + }) + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl Decode for TrackExtensions { + /// Decode Track Extensions - reads delta-encoded key-value pairs until end of buffer. + /// Per draft-16, Track Extensions have NO count or length prefix. + fn decode(r: &mut R) -> Result { + let mut kvps = Vec::new(); + let mut prev_key: u64 = 0; + + // Read until buffer is exhausted + while r.has_remaining() { + // Read delta type + let delta = u64::decode(r)?; + + // Reconstruct absolute key: prev_key + delta + let key = prev_key.checked_add(delta).ok_or_else(|| { + log::error!( + "[TrackExt] Delta type overflow: prev_key={}, delta={}", + prev_key, + delta + ); + DecodeError::BoundsExceeded(crate::coding::BoundsExceeded) + })?; + + let kvp = KeyValuePair::decode_value(key, r)?; + kvps.push(kvp); + prev_key = key; + } + + Ok(TrackExtensions(kvps)) + } +} + +impl Encode for TrackExtensions { + /// Encode Track Extensions - writes delta-encoded key-value pairs WITHOUT any prefix. + /// Entries are sorted by key in ascending order before encoding. + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + // Sort by key for delta encoding (Types must be in ascending order) + let mut sorted: Vec<&KeyValuePair> = self.0.iter().collect(); + sorted.sort_by_key(|kvp| kvp.key); + + let mut prev_key: u64 = 0; + for kvp in sorted { + // Compute and encode the delta + let delta = kvp.key.checked_sub(prev_key).ok_or_else(|| { + log::error!( + "[TrackExt] Keys not sortable: prev_key={}, current_key={}", + prev_key, + kvp.key + ); + EncodeError::InvalidValue + })?; + delta.encode(w)?; + + // Encode the value (without the key) + kvp.encode_value(w)?; + + prev_key = kvp.key; + } + + Ok(()) + } +} + +impl fmt::Debug for TrackExtensions { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{{ ")?; + for (i, kv) in self.0.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", kv)?; + } + write!(f, " }}") + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + #[test] + fn encode_decode_empty() { + let mut buf = BytesMut::new(); + + let ext = TrackExtensions::new(); + ext.encode(&mut buf).unwrap(); + // Empty TrackExtensions produces NO bytes (no prefix!) + let expected: Vec = vec![]; + assert_eq!(buf.to_vec(), expected); + let decoded = TrackExtensions::decode(&mut buf).unwrap(); + assert_eq!(decoded, ext); + } + + #[test] + fn encode_decode_single() { + let mut buf = BytesMut::new(); + + let mut ext = TrackExtensions::new(); + ext.set_intvalue(2, 42); // key=2 (even), value=42 + ext.encode(&mut buf).unwrap(); + + // Expected: delta=2, value=42 (no count or length prefix!) + assert_eq!(buf.to_vec(), vec![0x02, 0x2a]); + + let decoded = TrackExtensions::decode(&mut buf).unwrap(); + assert_eq!(decoded, ext); + } + + #[test] + fn encode_decode_multiple() { + let mut buf = BytesMut::new(); + + let mut ext = TrackExtensions::new(); + ext.set_intvalue(0, 0); + ext.set_intvalue(2, 100); + ext.encode(&mut buf).unwrap(); + + // Expected: + // Entry 1: delta=0, value=0 + // Entry 2: delta=2 (from 0), value=100 + // No count prefix! + #[rustfmt::skip] + let expected = vec![ + 0x00, 0x00, // delta=0, value=0 + 0x02, 0x40, 0x64, // delta=2, value=100 (varint) + ]; + assert_eq!(buf.to_vec(), expected); + + let decoded = TrackExtensions::decode(&mut buf).unwrap(); + assert_eq!(decoded, ext); + } +} diff --git a/moq-transport/src/data/datagram.rs b/moq-transport/src/data/datagram.rs index 7a521319..61d36ce1 100644 --- a/moq-transport/src/data/datagram.rs +++ b/moq-transport/src/data/datagram.rs @@ -3,6 +3,7 @@ use crate::data::{ExtensionHeaders, ObjectStatus}; #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum DatagramType { + // Payload types with Priority Present (0x00-0x07) ObjectIdPayload = 0x00, ObjectIdPayloadExt = 0x01, ObjectIdPayloadEndOfGroup = 0x02, @@ -11,13 +12,125 @@ pub enum DatagramType { PayloadExt = 0x05, PayloadEndOfGroup = 0x06, PayloadExtEndOfGroup = 0x07, + // Payload types with Priority Not Present (0x08-0x0F) + ObjectIdPayloadNoPriority = 0x08, + ObjectIdPayloadExtNoPriority = 0x09, + ObjectIdPayloadEndOfGroupNoPriority = 0x0a, + ObjectIdPayloadExtEndOfGroupNoPriority = 0x0b, + PayloadNoPriority = 0x0c, + PayloadExtNoPriority = 0x0d, + PayloadEndOfGroupNoPriority = 0x0e, + PayloadExtEndOfGroupNoPriority = 0x0f, + // Status types with Priority Present (0x20-0x25) ObjectIdStatus = 0x20, ObjectIdStatusExt = 0x21, + Status = 0x24, + StatusExt = 0x25, + // Status types with Priority Not Present (0x28-0x2D) + ObjectIdStatusNoPriority = 0x28, + ObjectIdStatusExtNoPriority = 0x29, + StatusNoPriority = 0x2c, + StatusExtNoPriority = 0x2d, +} + +impl DatagramType { + /// Returns true if this datagram type has the Object ID field present + pub fn has_object_id(&self) -> bool { + matches!( + *self, + DatagramType::ObjectIdPayload + | DatagramType::ObjectIdPayloadExt + | DatagramType::ObjectIdPayloadEndOfGroup + | DatagramType::ObjectIdPayloadExtEndOfGroup + | DatagramType::ObjectIdPayloadNoPriority + | DatagramType::ObjectIdPayloadExtNoPriority + | DatagramType::ObjectIdPayloadEndOfGroupNoPriority + | DatagramType::ObjectIdPayloadExtEndOfGroupNoPriority + | DatagramType::ObjectIdStatus + | DatagramType::ObjectIdStatusExt + | DatagramType::ObjectIdStatusNoPriority + | DatagramType::ObjectIdStatusExtNoPriority + ) + } + + /// Returns true if this datagram type has the Publisher Priority field present + pub fn has_priority(&self) -> bool { + matches!( + *self, + DatagramType::ObjectIdPayload + | DatagramType::ObjectIdPayloadExt + | DatagramType::ObjectIdPayloadEndOfGroup + | DatagramType::ObjectIdPayloadExtEndOfGroup + | DatagramType::Payload + | DatagramType::PayloadExt + | DatagramType::PayloadEndOfGroup + | DatagramType::PayloadExtEndOfGroup + | DatagramType::ObjectIdStatus + | DatagramType::ObjectIdStatusExt + | DatagramType::Status + | DatagramType::StatusExt + ) + } + + /// Returns true if this datagram type has extension headers + pub fn has_extensions(&self) -> bool { + matches!( + *self, + DatagramType::ObjectIdPayloadExt + | DatagramType::ObjectIdPayloadExtEndOfGroup + | DatagramType::PayloadExt + | DatagramType::PayloadExtEndOfGroup + | DatagramType::ObjectIdPayloadExtNoPriority + | DatagramType::ObjectIdPayloadExtEndOfGroupNoPriority + | DatagramType::PayloadExtNoPriority + | DatagramType::PayloadExtEndOfGroupNoPriority + | DatagramType::ObjectIdStatusExt + | DatagramType::StatusExt + | DatagramType::ObjectIdStatusExtNoPriority + | DatagramType::StatusExtNoPriority + ) + } + + /// Returns true if this is a status datagram (no payload) + pub fn is_status(&self) -> bool { + matches!( + *self, + DatagramType::ObjectIdStatus + | DatagramType::ObjectIdStatusExt + | DatagramType::Status + | DatagramType::StatusExt + | DatagramType::ObjectIdStatusNoPriority + | DatagramType::ObjectIdStatusExtNoPriority + | DatagramType::StatusNoPriority + | DatagramType::StatusExtNoPriority + ) + } + + /// Returns true if this is a payload datagram + pub fn is_payload(&self) -> bool { + !self.is_status() + } + + /// Returns true if this datagram type indicates end of group + pub fn is_end_of_group(&self) -> bool { + matches!( + *self, + DatagramType::ObjectIdPayloadEndOfGroup + | DatagramType::ObjectIdPayloadExtEndOfGroup + | DatagramType::PayloadEndOfGroup + | DatagramType::PayloadExtEndOfGroup + | DatagramType::ObjectIdPayloadEndOfGroupNoPriority + | DatagramType::ObjectIdPayloadExtEndOfGroupNoPriority + | DatagramType::PayloadEndOfGroupNoPriority + | DatagramType::PayloadExtEndOfGroupNoPriority + ) + } } impl Decode for DatagramType { fn decode(r: &mut B) -> Result { match u64::decode(r)? { + // Payload types with Priority Present (0x00-0x07) 0x00 => Ok(Self::ObjectIdPayload), 0x01 => Ok(Self::ObjectIdPayloadExt), 0x02 => Ok(Self::ObjectIdPayloadEndOfGroup), @@ -26,8 +139,25 @@ impl Decode for DatagramType { 0x05 => Ok(Self::PayloadExt), 0x06 => Ok(Self::PayloadEndOfGroup), 0x07 => Ok(Self::PayloadExtEndOfGroup), + // Payload types with Priority Not Present (0x08-0x0F) + 0x08 => Ok(Self::ObjectIdPayloadNoPriority), + 0x09 => Ok(Self::ObjectIdPayloadExtNoPriority), + 0x0a => Ok(Self::ObjectIdPayloadEndOfGroupNoPriority), + 0x0b => Ok(Self::ObjectIdPayloadExtEndOfGroupNoPriority), + 0x0c => Ok(Self::PayloadNoPriority), + 0x0d => Ok(Self::PayloadExtNoPriority), + 0x0e => Ok(Self::PayloadEndOfGroupNoPriority), + 0x0f => Ok(Self::PayloadExtEndOfGroupNoPriority), + // Status types with Priority Present (0x20-0x25) 0x20 => Ok(Self::ObjectIdStatus), 0x21 => Ok(Self::ObjectIdStatusExt), + 0x24 => Ok(Self::Status), + 0x25 => Ok(Self::StatusExt), + // Status types with Priority Not Present (0x28-0x2D) + 0x28 => Ok(Self::ObjectIdStatusNoPriority), + 0x29 => Ok(Self::ObjectIdStatusExtNoPriority), + 0x2c => Ok(Self::StatusNoPriority), + 0x2d => Ok(Self::StatusExtNoPriority), _ => Err(DecodeError::InvalidDatagramType), } } @@ -56,9 +186,10 @@ pub struct Datagram { pub object_id: Option, /// Publisher priority, where **smaller** values are sent first. - pub publisher_priority: u8, + /// Optional when using NoPriority datagram types (0x08-0x0F, 0x28-0x2D). + pub publisher_priority: Option, - /// Optional extension headers if type is 0x1 (NoEndOfGroupWithExtensions) or 0x3 (EndofGroupWithExtensions) + /// Optional extension headers for types with extensions pub extension_headers: Option, /// The Object Status. @@ -75,47 +206,38 @@ impl Decode for Datagram { let group_id = u64::decode(r)?; // Decode Object Id if required - let object_id = match datagram_type { - DatagramType::ObjectIdPayload - | DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadEndOfGroup - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::ObjectIdStatus - | DatagramType::ObjectIdStatusExt => Some(u64::decode(r)?), - _ => None, + let object_id = if datagram_type.has_object_id() { + Some(u64::decode(r)?) + } else { + None }; - let publisher_priority = u8::decode(r)?; + // Decode Publisher Priority if required + let publisher_priority = if datagram_type.has_priority() { + Some(u8::decode(r)?) + } else { + None + }; // Decode Extension Headers if required - let extension_headers = match datagram_type { - DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::PayloadExt - | DatagramType::PayloadExtEndOfGroup - | DatagramType::ObjectIdStatusExt => Some(ExtensionHeaders::decode(r)?), - _ => None, + let extension_headers = if datagram_type.has_extensions() { + Some(ExtensionHeaders::decode(r)?) + } else { + None }; - // Decode Status if required - let status = match datagram_type { - DatagramType::ObjectIdStatus | DatagramType::ObjectIdStatusExt => { - Some(ObjectStatus::decode(r)?) - } - _ => None, + // Decode Status if required (for status datagram types) + let status = if datagram_type.is_status() { + Some(ObjectStatus::decode(r)?) + } else { + None }; - // Decode Payload if required - let payload = match datagram_type { - DatagramType::ObjectIdPayload - | DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadEndOfGroup - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::Payload - | DatagramType::PayloadExt - | DatagramType::PayloadEndOfGroup - | DatagramType::PayloadExtEndOfGroup => Some(r.copy_to_bytes(r.remaining())), - _ => None, + // Decode Payload if required (for payload datagram types) + let payload = if datagram_type.is_payload() { + Some(r.copy_to_bytes(r.remaining())) + } else { + None }; Ok(Self { @@ -138,70 +260,49 @@ impl Encode for Datagram { self.group_id.encode(w)?; // Encode Object Id if required - match self.datagram_type { - DatagramType::ObjectIdPayload - | DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadEndOfGroup - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::ObjectIdStatus - | DatagramType::ObjectIdStatusExt => { - if let Some(object_id) = &self.object_id { - object_id.encode(w)?; - } else { - return Err(EncodeError::MissingField("ObjectId".to_string())); - } + if self.datagram_type.has_object_id() { + if let Some(object_id) = &self.object_id { + object_id.encode(w)?; + } else { + return Err(EncodeError::MissingField("ObjectId".to_string())); } - _ => {} - }; + } - self.publisher_priority.encode(w)?; + // Encode Publisher Priority if required + if self.datagram_type.has_priority() { + if let Some(publisher_priority) = &self.publisher_priority { + publisher_priority.encode(w)?; + } else { + return Err(EncodeError::MissingField("PublisherPriority".to_string())); + } + } // Encode Extension Headers if required - match self.datagram_type { - DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::PayloadExt - | DatagramType::PayloadExtEndOfGroup - | DatagramType::ObjectIdStatusExt => { - if let Some(extension_headers) = &self.extension_headers { - extension_headers.encode(w)?; - } else { - return Err(EncodeError::MissingField("ExtensionHeaders".to_string())); - } + if self.datagram_type.has_extensions() { + if let Some(extension_headers) = &self.extension_headers { + extension_headers.encode(w)?; + } else { + return Err(EncodeError::MissingField("ExtensionHeaders".to_string())); } - _ => {} - }; + } - // Decode Status if required - match self.datagram_type { - DatagramType::ObjectIdStatus | DatagramType::ObjectIdStatusExt => { - if let Some(status) = &self.status { - status.encode(w)?; - } else { - return Err(EncodeError::MissingField("Status".to_string())); - } + // Encode Status if required (for status datagram types) + if self.datagram_type.is_status() { + if let Some(status) = &self.status { + status.encode(w)?; + } else { + return Err(EncodeError::MissingField("Status".to_string())); } - _ => {} } - // Decode Payload if required - match self.datagram_type { - DatagramType::ObjectIdPayload - | DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadEndOfGroup - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::Payload - | DatagramType::PayloadExt - | DatagramType::PayloadEndOfGroup - | DatagramType::PayloadExtEndOfGroup => { - if let Some(payload) = &self.payload { - Self::encode_remaining(w, payload.len())?; - w.put_slice(payload); - } else { - return Err(EncodeError::MissingField("Payload".to_string())); - } + // Encode Payload if required (for payload datagram types) + if self.datagram_type.is_payload() { + if let Some(payload) = &self.payload { + Self::encode_remaining(w, payload.len())?; + w.put_slice(payload); + } else { + return Err(EncodeError::MissingField("Payload".to_string())); } - _ => {} } Ok(()) @@ -293,7 +394,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -310,7 +411,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: Some(ext_hdrs.clone()), status: None, payload: Some(Bytes::from("payload")), @@ -327,7 +428,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -344,7 +445,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: Some(ext_hdrs.clone()), status: None, payload: Some(Bytes::from("payload")), @@ -361,7 +462,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: Some(ObjectStatus::EndOfTrack), payload: None, @@ -378,7 +479,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: Some(ext_hdrs.clone()), status: Some(ObjectStatus::EndOfTrack), payload: None, @@ -395,7 +496,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: None, - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -412,7 +513,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: None, - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: Some(ext_hdrs.clone()), status: None, payload: Some(Bytes::from("payload")), @@ -429,7 +530,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: None, - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -446,7 +547,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: None, - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: Some(ext_hdrs.clone()), status: None, payload: Some(Bytes::from("payload")), @@ -456,6 +557,40 @@ mod tests { assert_eq!(19, buf.len()); let decoded = Datagram::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); + + // DatagramType = ObjectIdPayloadNoPriority (no priority field) + let msg = Datagram { + datagram_type: DatagramType::ObjectIdPayloadNoPriority, + track_alias: 12, + group_id: 10, + object_id: Some(1234), + publisher_priority: None, + extension_headers: None, + status: None, + payload: Some(Bytes::from("payload")), + }; + msg.encode(&mut buf).unwrap(); + // Length should be: Type(1)+Alias(1)+GroupId(1)+ObjectId(2)+Payload(7) = 12 (no priority) + assert_eq!(12, buf.len()); + let decoded = Datagram::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + + // DatagramType = PayloadNoPriority (no priority field, no object id) + let msg = Datagram { + datagram_type: DatagramType::PayloadNoPriority, + track_alias: 12, + group_id: 10, + object_id: None, + publisher_priority: None, + extension_headers: None, + status: None, + payload: Some(Bytes::from("payload")), + }; + msg.encode(&mut buf).unwrap(); + // Length should be: Type(1)+Alias(1)+GroupId(1)+Payload(7) = 10 (no priority, no object id) + assert_eq!(10, buf.len()); + let decoded = Datagram::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); } #[test] @@ -468,7 +603,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -482,7 +617,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -496,7 +631,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: Some(ObjectStatus::EndOfTrack), payload: None, @@ -510,7 +645,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: None, - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: None, @@ -524,7 +659,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: None, @@ -532,6 +667,18 @@ mod tests { let encoded = msg.encode(&mut buf); assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - // TODO SLG - add tests + // DatagramType = ObjectIdPayload - missing priority (priority is required for this type) + let msg = Datagram { + datagram_type: DatagramType::ObjectIdPayload, + track_alias: 12, + group_id: 10, + object_id: Some(1234), + publisher_priority: None, + extension_headers: None, + status: None, + payload: Some(Bytes::from("payload")), + }; + let encoded = msg.encode(&mut buf); + assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); } } diff --git a/moq-transport/src/data/extension_headers.rs b/moq-transport/src/data/extension_headers.rs index f6dba873..22191548 100644 --- a/moq-transport/src/data/extension_headers.rs +++ b/moq-transport/src/data/extension_headers.rs @@ -4,6 +4,8 @@ use std::fmt; /// A collection of KeyValuePair entries, where the length in bytes of key-value-pairs are encoded/decoded first. /// This structure is appropriate for Data plane extension headers. +/// +/// Per draft-16 Section 1.4.2, Key-Value-Pairs use delta-encoded Type fields. /// Since duplicate parameters are allowed for unknown extension headers, we don't do duplicate checking here. #[derive(Default, Clone, Eq, PartialEq)] pub struct ExtensionHeaders(pub Vec); @@ -44,7 +46,40 @@ impl ExtensionHeaders { } } +impl ExtensionHeaders { + /// Decode extension headers from remaining bytes (no length prefix). + /// Used for Track Extensions in PUBLISH where the length is implicit from the message. + pub fn decode_remaining_bytes(r: &mut R) -> Result { + if !r.has_remaining() { + return Ok(ExtensionHeaders::new()); + } + + let mut kvps = Vec::new(); + let mut prev_key: u64 = 0; + + while r.has_remaining() { + // Read delta type and reconstruct absolute key + let delta = u64::decode(r)?; + let key = prev_key.checked_add(delta).ok_or_else(|| { + log::error!( + "[ExtHdr] Delta type overflow: prev_key={}, delta={}", + prev_key, + delta + ); + DecodeError::BoundsExceeded(crate::coding::BoundsExceeded) + })?; + + let kvp = KeyValuePair::decode_value(key, r)?; + kvps.push(kvp); + prev_key = key; + } + + Ok(ExtensionHeaders(kvps)) + } +} + impl Decode for ExtensionHeaders { + /// Decode extension headers with delta-encoded Type fields (draft-16 Section 1.4.2). fn decode(r: &mut R) -> Result { // Read total byte length of the encoded kvps // Note: this is the difference between KeyValuePairs and ExtensionHeaders. @@ -65,9 +100,23 @@ impl Decode for ExtensionHeaders { let mut kvps_bytes = bytes::Bytes::from(buf); let mut kvps = Vec::new(); + let mut prev_key: u64 = 0; + while kvps_bytes.has_remaining() { - let kvp = KeyValuePair::decode(&mut kvps_bytes)?; + // Read delta type and reconstruct absolute key + let delta = u64::decode(&mut kvps_bytes)?; + let key = prev_key.checked_add(delta).ok_or_else(|| { + log::error!( + "[ExtHdr] Delta type overflow: prev_key={}, delta={}", + prev_key, + delta + ); + DecodeError::BoundsExceeded(crate::coding::BoundsExceeded) + })?; + + let kvp = KeyValuePair::decode_value(key, &mut kvps_bytes)?; kvps.push(kvp); + prev_key = key; } Ok(ExtensionHeaders(kvps)) @@ -75,14 +124,31 @@ impl Decode for ExtensionHeaders { } impl Encode for ExtensionHeaders { + /// Encode extension headers with delta-encoded Type fields (draft-16 Section 1.4.2). + /// Entries are sorted by key in ascending order before encoding. fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - // Encode all KeyValuePair entries into a temporary buffer to compute total byte length + // Sort by key for delta encoding + let mut sorted: Vec<&KeyValuePair> = self.0.iter().collect(); + sorted.sort_by_key(|kvp| kvp.key); + + // Encode all entries into a temporary buffer to compute total byte length let mut tmp = bytes::BytesMut::new(); - for kvp in &self.0 { - kvp.encode(&mut tmp)?; + let mut prev_key: u64 = 0; + for kvp in sorted { + let delta = kvp.key.checked_sub(prev_key).ok_or_else(|| { + log::error!( + "[ExtHdr] Keys not sortable: prev_key={}, current_key={}", + prev_key, + kvp.key + ); + EncodeError::InvalidValue + })?; + delta.encode(&mut tmp)?; + kvp.encode_value(&mut tmp)?; + prev_key = kvp.key; } - // Write total byte length (u64) followed by the encoded bytes + // Write total byte length followed by the encoded bytes (tmp.len() as u64).encode(w)?; w.put_slice(&tmp); @@ -109,9 +175,10 @@ mod tests { use bytes::BytesMut; #[test] - fn encode_decode_extension_headers() { + fn encode_decode_extension_headers_single() { let mut buf = BytesMut::new(); + // Single entry: key=1. Delta from 0 = 1. let mut ext_hdrs = ExtensionHeaders::new(); ext_hdrs.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); ext_hdrs.encode(&mut buf).unwrap(); @@ -119,21 +186,55 @@ mod tests { buf.to_vec(), vec![ 0x07, // 7 bytes total length - 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, // Key=1, Value=[1,2,3,4,5] + // Delta=1, length=5, data + 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, ] ); let decoded = ExtensionHeaders::decode(&mut buf).unwrap(); assert_eq!(decoded, ext_hdrs); + } + + #[test] + fn encode_decode_extension_headers_multiple() { + let mut buf = BytesMut::new(); + // Multiple entries inserted out of order — encoding sorts by key. + // Keys: 0 (even, int), 1 (odd, bytes), 100 (even, int) let mut ext_hdrs = ExtensionHeaders::new(); - ext_hdrs.set_intvalue(0, 0); // 2 bytes - ext_hdrs.set_intvalue(100, 100); // 4 bytes - ext_hdrs.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); // 1 byte key, 1 byte length, 5 bytes data = 7 bytes + ext_hdrs.set_intvalue(0, 0); + ext_hdrs.set_intvalue(100, 100); + ext_hdrs.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); ext_hdrs.encode(&mut buf).unwrap(); let buf_vec = buf.to_vec(); - // Validate the encoded length and the KeyValuePair's length. - assert_eq!(14, buf_vec.len()); // 14 bytes total (length + 3 kvps) - assert_eq!(13, buf_vec[0]); // 13 bytes for the 3 KeyValuePairs data + + #[rustfmt::skip] + let expected = vec![ + 0x0d, // 13 bytes total length for the KVP data + // Entry 1: key=0 (delta=0), even, int value=0 + 0x00, 0x00, + // Entry 2: key=1 (delta=1), odd, bytes len=5 + 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, + // Entry 3: key=100 (delta=99), even, int value=100 + 0x40, 0x63, 0x40, 0x64, + ]; + assert_eq!(buf_vec, expected); + + // Decode and verify — decoded entries will be in sorted order + let decoded = ExtensionHeaders::decode(&mut buf).unwrap(); + let mut expected_ext = ExtensionHeaders::new(); + expected_ext.set_intvalue(0, 0); + expected_ext.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); + expected_ext.set_intvalue(100, 100); + assert_eq!(decoded, expected_ext); + } + + #[test] + fn encode_decode_extension_headers_empty() { + let mut buf = BytesMut::new(); + + let ext_hdrs = ExtensionHeaders::new(); + ext_hdrs.encode(&mut buf).unwrap(); + assert_eq!(buf.to_vec(), vec![0x00]); // length=0 let decoded = ExtensionHeaders::decode(&mut buf).unwrap(); assert_eq!(decoded, ext_hdrs); } diff --git a/moq-transport/src/data/extension_types.rs b/moq-transport/src/data/extension_types.rs new file mode 100644 index 00000000..7c83de7a --- /dev/null +++ b/moq-transport/src/data/extension_types.rs @@ -0,0 +1,38 @@ +//! Known extension header type constants for the MOQT data plane. +//! +//! These extension headers can be attached to objects in subgroups, datagrams, and fetch streams. +//! See the MOQT specification for detailed semantics of each extension type. + +/// Immutable Extensions (0xB) +/// +/// A container extension header that wraps other extension headers that MUST NOT +/// be modified by relays or intermediaries. The contents of this extension header +/// should be preserved exactly as received when forwarding objects. +pub const IMMUTABLE_EXTENSIONS: u64 = 0xB; + +/// Prior Group ID Gap (0x3C) +/// +/// Indicates that one or more groups prior to this one are missing or unavailable. +/// The value is an integer indicating the number of missing prior groups. +/// This is used to signal discontinuities in the group sequence to subscribers. +pub const PRIOR_GROUP_ID_GAP: u64 = 0x3C; + +/// Prior Object ID Gap (0x3E) +/// +/// Indicates that one or more objects prior to this one within the same group/subgroup +/// are missing or unavailable. The value is an integer indicating the number of missing +/// prior objects. This is used to signal discontinuities in the object sequence. +pub const PRIOR_OBJECT_ID_GAP: u64 = 0x3E; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extension_type_values() { + // Verify the spec-defined values + assert_eq!(IMMUTABLE_EXTENSIONS, 0xB); + assert_eq!(PRIOR_GROUP_ID_GAP, 0x3C); + assert_eq!(PRIOR_OBJECT_ID_GAP, 0x3E); + } +} diff --git a/moq-transport/src/data/header.rs b/moq-transport/src/data/header.rs index b8274b4a..3991e759 100644 --- a/moq-transport/src/data/header.rs +++ b/moq-transport/src/data/header.rs @@ -6,6 +6,7 @@ use std::fmt; #[repr(u64)] #[derive(Copy, Debug, Clone, Eq, PartialEq)] pub enum StreamHeaderType { + // Priority Present variants (0x10-0x1D) SubgroupZeroId = 0x10, SubgroupZeroIdExt = 0x11, SubgroupFirstObjectId = 0x12, @@ -18,13 +19,27 @@ pub enum StreamHeaderType { SubgroupFirstObjectIdExtEndOfGroup = 0x1b, SubgroupIdEndOfGroup = 0x1c, SubgroupIdExtEndOfGroup = 0x1d, + // Priority Not Present variants (0x30-0x3D) + SubgroupZeroIdNoPriority = 0x30, + SubgroupZeroIdExtNoPriority = 0x31, + SubgroupFirstObjectIdNoPriority = 0x32, + SubgroupFirstObjectIdExtNoPriority = 0x33, + SubgroupIdNoPriority = 0x34, + SubgroupIdExtNoPriority = 0x35, + SubgroupZeroIdEndOfGroupNoPriority = 0x38, + SubgroupZeroIdExtEndOfGroupNoPriority = 0x39, + SubgroupFirstObjectIdEndOfGroupNoPriority = 0x3a, + SubgroupFirstObjectIdExtEndOfGroupNoPriority = 0x3b, + SubgroupIdEndOfGroupNoPriority = 0x3c, + SubgroupIdExtEndOfGroupNoPriority = 0x3d, + // Fetch Fetch = 0x5, } impl StreamHeaderType { pub fn is_subgroup(&self) -> bool { let header_type = *self as u64; - (0x10..=0x1d).contains(&header_type) + (0x10..=0x1d).contains(&header_type) || (0x30..=0x3d).contains(&header_type) } pub fn is_fetch(&self) -> bool { @@ -40,6 +55,12 @@ impl StreamHeaderType { | StreamHeaderType::SubgroupZeroIdExtEndOfGroup | StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroup | StreamHeaderType::SubgroupIdExtEndOfGroup + | StreamHeaderType::SubgroupZeroIdExtNoPriority + | StreamHeaderType::SubgroupFirstObjectIdExtNoPriority + | StreamHeaderType::SubgroupIdExtNoPriority + | StreamHeaderType::SubgroupZeroIdExtEndOfGroupNoPriority + | StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroupNoPriority + | StreamHeaderType::SubgroupIdExtEndOfGroupNoPriority | StreamHeaderType::Fetch ) } @@ -51,8 +72,60 @@ impl StreamHeaderType { | StreamHeaderType::SubgroupIdExt | StreamHeaderType::SubgroupIdEndOfGroup | StreamHeaderType::SubgroupIdExtEndOfGroup + | StreamHeaderType::SubgroupIdNoPriority + | StreamHeaderType::SubgroupIdExtNoPriority + | StreamHeaderType::SubgroupIdEndOfGroupNoPriority + | StreamHeaderType::SubgroupIdExtEndOfGroupNoPriority + ) + } + + pub fn has_priority(&self) -> bool { + let header_type = *self as u64; + // Priority Present variants are 0x10-0x1D + // Priority Not Present variants are 0x30-0x3D + (0x10..=0x1d).contains(&header_type) + } + + /// Returns true if this header type signals end-of-group when the stream ends. + /// For these types, the relay should write an EndOfGroup marker when the stream completes. + pub fn signals_end_of_group(&self) -> bool { + matches!( + *self, + StreamHeaderType::SubgroupZeroIdEndOfGroup + | StreamHeaderType::SubgroupZeroIdExtEndOfGroup + | StreamHeaderType::SubgroupFirstObjectIdEndOfGroup + | StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroup + | StreamHeaderType::SubgroupIdEndOfGroup + | StreamHeaderType::SubgroupIdExtEndOfGroup + | StreamHeaderType::SubgroupZeroIdEndOfGroupNoPriority + | StreamHeaderType::SubgroupZeroIdExtEndOfGroupNoPriority + | StreamHeaderType::SubgroupFirstObjectIdEndOfGroupNoPriority + | StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroupNoPriority + | StreamHeaderType::SubgroupIdEndOfGroupNoPriority + | StreamHeaderType::SubgroupIdExtEndOfGroupNoPriority ) } + + /// Returns the equivalent header type without extensions. + /// Used when forwarding streams where objects have empty extension headers. + pub fn without_extensions(&self) -> Self { + match *self { + StreamHeaderType::SubgroupZeroIdExt => StreamHeaderType::SubgroupZeroId, + StreamHeaderType::SubgroupFirstObjectIdExt => StreamHeaderType::SubgroupFirstObjectId, + StreamHeaderType::SubgroupIdExt => StreamHeaderType::SubgroupId, + StreamHeaderType::SubgroupZeroIdExtEndOfGroup => StreamHeaderType::SubgroupZeroIdEndOfGroup, + StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroup => StreamHeaderType::SubgroupFirstObjectIdEndOfGroup, + StreamHeaderType::SubgroupIdExtEndOfGroup => StreamHeaderType::SubgroupIdEndOfGroup, + StreamHeaderType::SubgroupZeroIdExtNoPriority => StreamHeaderType::SubgroupZeroIdNoPriority, + StreamHeaderType::SubgroupFirstObjectIdExtNoPriority => StreamHeaderType::SubgroupFirstObjectIdNoPriority, + StreamHeaderType::SubgroupIdExtNoPriority => StreamHeaderType::SubgroupIdNoPriority, + StreamHeaderType::SubgroupZeroIdExtEndOfGroupNoPriority => StreamHeaderType::SubgroupZeroIdEndOfGroupNoPriority, + StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroupNoPriority => StreamHeaderType::SubgroupFirstObjectIdEndOfGroupNoPriority, + StreamHeaderType::SubgroupIdExtEndOfGroupNoPriority => StreamHeaderType::SubgroupIdEndOfGroupNoPriority, + // Already non-Ext or Fetch + other => other, + } + } } impl Encode for StreamHeaderType { @@ -83,6 +156,7 @@ impl Decode for StreamHeaderType { ); let header_type = match type_value { + // Priority Present variants (0x10-0x1D) 0x10_u64 => Ok(Self::SubgroupZeroId), 0x11_u64 => Ok(Self::SubgroupZeroIdExt), 0x12_u64 => Ok(Self::SubgroupFirstObjectId), @@ -95,6 +169,20 @@ impl Decode for StreamHeaderType { 0x1b_u64 => Ok(Self::SubgroupFirstObjectIdExtEndOfGroup), 0x1c_u64 => Ok(Self::SubgroupIdEndOfGroup), 0x1d_u64 => Ok(Self::SubgroupIdExtEndOfGroup), + // Priority Not Present variants (0x30-0x3D) + 0x30_u64 => Ok(Self::SubgroupZeroIdNoPriority), + 0x31_u64 => Ok(Self::SubgroupZeroIdExtNoPriority), + 0x32_u64 => Ok(Self::SubgroupFirstObjectIdNoPriority), + 0x33_u64 => Ok(Self::SubgroupFirstObjectIdExtNoPriority), + 0x34_u64 => Ok(Self::SubgroupIdNoPriority), + 0x35_u64 => Ok(Self::SubgroupIdExtNoPriority), + 0x38_u64 => Ok(Self::SubgroupZeroIdEndOfGroupNoPriority), + 0x39_u64 => Ok(Self::SubgroupZeroIdExtEndOfGroupNoPriority), + 0x3a_u64 => Ok(Self::SubgroupFirstObjectIdEndOfGroupNoPriority), + 0x3b_u64 => Ok(Self::SubgroupFirstObjectIdExtEndOfGroupNoPriority), + 0x3c_u64 => Ok(Self::SubgroupIdEndOfGroupNoPriority), + 0x3d_u64 => Ok(Self::SubgroupIdExtEndOfGroupNoPriority), + // Fetch 0x05_u64 => Ok(Self::Fetch), _ => { log::error!( @@ -290,7 +378,31 @@ mod tests { track_alias: 10, group_id: 0, subgroup_id: Some(1), - publisher_priority: 100, + publisher_priority: Some(100), + }), + fetch_header: None, + }; + sh.encode(&mut buf).unwrap(); + let decoded = StreamHeader::decode(&mut buf).unwrap(); + assert_eq!(decoded, sh); + assert!(sh.header_type.is_subgroup()); + assert!(!sh.header_type.is_fetch()); + assert!(sh.header_type.has_subgroup_id()); + } + + #[test] + fn encode_decode_stream_header_no_priority() { + let mut buf = BytesMut::new(); + + // Test a NoPriority subgroup header type + let sh = StreamHeader { + header_type: StreamHeaderType::SubgroupIdNoPriority, + subgroup_header: Some(SubgroupHeader { + header_type: StreamHeaderType::SubgroupIdNoPriority, + track_alias: 10, + group_id: 0, + subgroup_id: Some(1), + publisher_priority: None, }), fetch_header: None, }; @@ -300,5 +412,6 @@ mod tests { assert!(sh.header_type.is_subgroup()); assert!(!sh.header_type.is_fetch()); assert!(sh.header_type.has_subgroup_id()); + assert!(!sh.header_type.has_priority()); } } diff --git a/moq-transport/src/data/mod.rs b/moq-transport/src/data/mod.rs index d76ba871..0d0025ab 100644 --- a/moq-transport/src/data/mod.rs +++ b/moq-transport/src/data/mod.rs @@ -1,5 +1,6 @@ mod datagram; mod extension_headers; +mod extension_types; mod fetch; mod header; mod object_status; @@ -7,6 +8,7 @@ mod subgroup; pub use datagram::*; pub use extension_headers::*; +pub use extension_types::*; pub use fetch::*; pub use header::*; pub use object_status::*; diff --git a/moq-transport/src/data/subgroup.rs b/moq-transport/src/data/subgroup.rs index 45e89e9f..9cfb1127 100644 --- a/moq-transport/src/data/subgroup.rs +++ b/moq-transport/src/data/subgroup.rs @@ -16,7 +16,8 @@ pub struct SubgroupHeader { pub subgroup_id: Option, /// Publisher priority, where **smaller** values are sent first. - pub publisher_priority: u8, + /// Optional when using NoPriority header types (0x30-0x3D). + pub publisher_priority: Option, } // Note: Not using the Decode trait, since we need to know the header_type to properly parse this, and it @@ -52,12 +53,20 @@ impl SubgroupHeader { } }; - let publisher_priority = u8::decode(r)?; - log::trace!( - "[DECODE] SubgroupHeader: publisher_priority={}, buffer_remaining={} bytes", - publisher_priority, - r.remaining() - ); + let publisher_priority = if header_type.has_priority() { + let priority = u8::decode(r)?; + log::trace!( + "[DECODE] SubgroupHeader: publisher_priority={}, buffer_remaining={} bytes", + priority, + r.remaining() + ); + Some(priority) + } else { + log::trace!( + "[DECODE] SubgroupHeader: publisher_priority=None (not present for NoPriority header type)" + ); + None + }; let result = Self { header_type, @@ -68,7 +77,7 @@ impl SubgroupHeader { }; log::debug!( - "[DECODE] SubgroupHeader complete: track_alias={}, group_id={}, subgroup_id={:?}, priority={}", + "[DECODE] SubgroupHeader complete: track_alias={}, group_id={}, subgroup_id={:?}, priority={:?}", result.track_alias, result.group_id, result.subgroup_id, @@ -82,7 +91,7 @@ impl SubgroupHeader { impl Encode for SubgroupHeader { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { log::trace!( - "[ENCODE] SubgroupHeader: starting encode - track_alias={}, group_id={}, subgroup_id={:?}, priority={}, header_type={:?}", + "[ENCODE] SubgroupHeader: starting encode - track_alias={}, group_id={}, subgroup_id={:?}, priority={:?}, header_type={:?}", self.track_alias, self.group_id, self.subgroup_id, @@ -125,11 +134,25 @@ impl Encode for SubgroupHeader { log::trace!("[ENCODE] SubgroupHeader: subgroup_id not encoded (not required for this header type)"); } - self.publisher_priority.encode(w)?; - log::trace!( - "[ENCODE] SubgroupHeader: encoded publisher_priority={}", - self.publisher_priority - ); + if self.header_type.has_priority() { + if let Some(publisher_priority) = self.publisher_priority { + publisher_priority.encode(w)?; + log::trace!( + "[ENCODE] SubgroupHeader: encoded publisher_priority={}", + publisher_priority + ); + } else { + log::error!( + "[ENCODE] SubgroupHeader: MISSING publisher_priority for header_type={:?}", + self.header_type + ); + return Err(EncodeError::MissingField("PublisherPriority".to_string())); + } + } else { + log::trace!( + "[ENCODE] SubgroupHeader: publisher_priority not encoded (NoPriority header type)" + ); + } let bytes_written = start_pos - w.remaining_mut(); log::debug!( diff --git a/moq-transport/src/message/dynamic_groups.rs b/moq-transport/src/message/dynamic_groups.rs new file mode 100644 index 00000000..4a5b7708 --- /dev/null +++ b/moq-transport/src/message/dynamic_groups.rs @@ -0,0 +1,131 @@ +//! Dynamic Groups support for MOQT. +//! +//! This module provides helper functions for working with Dynamic Groups parameters +//! as defined in the MOQT specification. Dynamic Groups allow subscribers to request +//! publishers to create new groups on demand. + +use crate::coding::KeyValuePairs; +use crate::message::ParameterType; + +/// Helper trait for Dynamic Groups parameter operations on KeyValuePairs. +pub trait DynamicGroupsExt { + /// Check if dynamic groups are enabled/supported + fn has_dynamic_groups(&self) -> bool; + + /// Get the dynamic groups value (if present) + fn get_dynamic_groups(&self) -> Option; + + /// Enable dynamic groups support + fn set_dynamic_groups(&mut self, value: u64); + + /// Check if a new group request is present + fn has_new_group_request(&self) -> bool; + + /// Get the new group request value (if present) + fn get_new_group_request(&self) -> Option; + + /// Request a new group from the publisher + fn set_new_group_request(&mut self, value: u64); +} + +impl DynamicGroupsExt for KeyValuePairs { + fn has_dynamic_groups(&self) -> bool { + self.has(ParameterType::DynamicGroups.into()) + } + + fn get_dynamic_groups(&self) -> Option { + self.get_intvalue(ParameterType::DynamicGroups.into()) + } + + fn set_dynamic_groups(&mut self, value: u64) { + self.set_intvalue(ParameterType::DynamicGroups.into(), value); + } + + fn has_new_group_request(&self) -> bool { + self.has(ParameterType::NewGroupRequest.into()) + } + + fn get_new_group_request(&self) -> Option { + self.get_intvalue(ParameterType::NewGroupRequest.into()) + } + + fn set_new_group_request(&mut self, value: u64) { + self.set_intvalue(ParameterType::NewGroupRequest.into(), value); + } +} + +/// Dynamic Groups configuration for a track +#[derive(Clone, Debug, Default)] +pub struct DynamicGroupsConfig { + /// Whether dynamic groups are enabled for this track + pub enabled: bool, + /// The current pending new group request (if any) + pub pending_request: Option, +} + +impl DynamicGroupsConfig { + /// Create a new configuration with dynamic groups disabled + pub fn new() -> Self { + Self::default() + } + + /// Create a new configuration with dynamic groups enabled + pub fn enabled() -> Self { + Self { + enabled: true, + pending_request: None, + } + } + + /// Request a new group with the given request ID + pub fn request_new_group(&mut self, request_id: u64) { + self.pending_request = Some(request_id); + } + + /// Clear the pending request (after it has been processed) + pub fn clear_pending_request(&mut self) { + self.pending_request = None; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dynamic_groups_ext() { + let mut params = KeyValuePairs::new(); + + // Initially no dynamic groups + assert!(!params.has_dynamic_groups()); + assert_eq!(params.get_dynamic_groups(), None); + + // Enable dynamic groups + params.set_dynamic_groups(1); + assert!(params.has_dynamic_groups()); + assert_eq!(params.get_dynamic_groups(), Some(1)); + + // New group request + assert!(!params.has_new_group_request()); + params.set_new_group_request(42); + assert!(params.has_new_group_request()); + assert_eq!(params.get_new_group_request(), Some(42)); + } + + #[test] + fn test_dynamic_groups_config() { + let config = DynamicGroupsConfig::new(); + assert!(!config.enabled); + assert!(config.pending_request.is_none()); + + let config = DynamicGroupsConfig::enabled(); + assert!(config.enabled); + + let mut config = DynamicGroupsConfig::enabled(); + config.request_new_group(123); + assert_eq!(config.pending_request, Some(123)); + + config.clear_pending_request(); + assert!(config.pending_request.is_none()); + } +} diff --git a/moq-transport/src/message/fetch_error.rs b/moq-transport/src/message/fetch_error.rs deleted file mode 100644 index b1acc55b..00000000 --- a/moq-transport/src/message/fetch_error.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct FetchError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for FetchError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for FetchError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/fetch_ok.rs b/moq-transport/src/message/fetch_ok.rs index c8db3c04..f3b721ed 100644 --- a/moq-transport/src/message/fetch_ok.rs +++ b/moq-transport/src/message/fetch_ok.rs @@ -1,4 +1,5 @@ use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; +use crate::data::ExtensionHeaders; use crate::message::GroupOrder; /// A publisher sends a FETCH_OK control message in response to successful fetches. @@ -18,6 +19,9 @@ pub struct FetchOk { /// Optional parameters pub params: KeyValuePairs, + + /// Track extensions + pub track_extensions: ExtensionHeaders, } impl Decode for FetchOk { @@ -33,6 +37,7 @@ impl Decode for FetchOk { let end_of_track = bool::decode(r)?; let end_location = Location::decode(r)?; let params = KeyValuePairs::decode(r)?; + let track_extensions = ExtensionHeaders::decode(r)?; Ok(Self { id, @@ -40,6 +45,7 @@ impl Decode for FetchOk { end_of_track, end_location, params, + track_extensions, }) } } @@ -57,6 +63,7 @@ impl Encode for FetchOk { self.end_of_track.encode(w)?; self.end_location.encode(w)?; self.params.encode(w)?; + self.track_extensions.encode(w)?; Ok(()) } @@ -81,6 +88,7 @@ mod tests { end_of_track: true, end_location: Location::new(2, 3), params: kvps.clone(), + track_extensions: Default::default(), }; msg.encode(&mut buf).unwrap(); let decoded = FetchOk::decode(&mut buf).unwrap(); @@ -97,6 +105,7 @@ mod tests { end_of_track: true, end_location: Location::new(2, 3), params: Default::default(), + track_extensions: Default::default(), }; let encoded = msg.encode(&mut buf); assert!(matches!(encoded.unwrap_err(), EncodeError::InvalidValue)); diff --git a/moq-transport/src/message/mod.rs b/moq-transport/src/message/mod.rs index 267368c7..246474a8 100644 --- a/moq-transport/src/message/mod.rs +++ b/moq-transport/src/message/mod.rs @@ -5,73 +5,65 @@ //! The only exception are OBJECT "messages", which are sent over dedicated QUIC streams. //! +mod dynamic_groups; mod fetch; mod fetch_cancel; -mod fetch_error; mod fetch_ok; mod fetch_type; mod filter_type; mod go_away; mod group_order; mod max_request_id; -mod pubilsh_namespace_done; +mod namespace; +mod parameters; mod publish; mod publish_done; -mod publish_error; mod publish_namespace; mod publish_namespace_cancel; -mod publish_namespace_error; -mod publish_namespace_ok; +mod publish_namespace_done; mod publish_ok; mod publisher; +mod request_error; +mod request_ok; mod requests_blocked; mod subscribe; -mod subscribe_error; mod subscribe_namespace; -mod subscribe_namespace_error; -mod subscribe_namespace_ok; mod subscribe_ok; mod subscribe_update; mod subscriber; mod track_status; -mod track_status_error; mod track_status_ok; mod unsubscribe; -mod unsubscribe_namespace; +pub use dynamic_groups::*; pub use fetch::*; pub use fetch_cancel::*; -pub use fetch_error::*; pub use fetch_ok::*; pub use fetch_type::*; pub use filter_type::*; pub use go_away::*; pub use group_order::*; pub use max_request_id::*; -pub use pubilsh_namespace_done::*; +pub use namespace::*; +pub use parameters::*; pub use publish::*; pub use publish_done::*; -pub use publish_error::*; pub use publish_namespace::*; pub use publish_namespace_cancel::*; -pub use publish_namespace_error::*; -pub use publish_namespace_ok::*; +pub use publish_namespace_done::*; pub use publish_ok::*; pub use publisher::*; +pub use request_error::*; +pub use request_ok::*; pub use requests_blocked::*; pub use subscribe::*; -pub use subscribe_error::*; pub use subscribe_namespace::*; -pub use subscribe_namespace_error::*; -pub use subscribe_namespace_ok::*; pub use subscribe_ok::*; pub use subscribe_update::*; pub use subscriber::*; pub use track_status::*; -pub use track_status_error::*; pub use track_status_ok::*; pub use unsubscribe::*; -pub use unsubscribe_namespace::*; use crate::coding::{Decode, DecodeError, Encode, EncodeError}; use std::fmt; @@ -89,13 +81,18 @@ macro_rules! message_types { impl Decode for Message { fn decode(r: &mut R) -> Result { let t = u64::decode(r)?; - let _len = u16::decode(r)?; + let len = u16::decode(r)? as usize; - // TODO: Check the length of the message. + // Read exactly len bytes into a sub-buffer to properly handle Track Extensions + if r.remaining() < len { + return Err(DecodeError::More(len - r.remaining())); + } + let payload = r.copy_to_bytes(len); + let mut payload_reader = std::io::Cursor::new(payload); match t { $($val => { - let msg = $name::decode(r)?; + let msg = $name::decode(&mut payload_reader)?; Ok(Self::$name(msg)) })* _ => Err(DecodeError::InvalidMessage(t)), @@ -185,40 +182,36 @@ message_types! { Unsubscribe = 0xa, // SUBSCRIBE family, sent by publisher SubscribeOk = 0x4, - SubscribeError = 0x5, // ANNOUNCE family, sent by publisher PublishNamespace = 0x6, PublishNamespaceDone = 0x9, // ANNOUNCE family, sent by subscriber - PublishNamespaceOk = 0x7, - PublishNamespaceError = 0x8, + RequestOk = 0x7, PublishNamespaceCancel = 0xc, + // NAMESPACE family, sent by relay to subscriber (draft-16) + Namespace = 0x8, + // TRACK_STATUS family, sent by subscriber TrackStatus = 0xd, // TRACK_STATUS family, sent by publisher TrackStatusOk = 0xe, - TrackStatusError = 0xf, // NAMESPACE family, sent by subscriber SubscribeNamespace = 0x11, - UnsubscribeNamespace = 0x14, - // NAMESPACE family, sent by publisher - SubscribeNamespaceOk = 0x12, - SubscribeNamespaceError = 0x13, // FETCH family, sent by subscriber Fetch = 0x16, FetchCancel = 0x17, // FETCH family, sent by publisher FetchOk = 0x18, - FetchError = 0x19, // PUBLISH family, sent by publisher Publish = 0x1d, PublishDone = 0xb, // PUBLISH family, sent by subscriber PublishOk = 0x1e, - PublishError = 0x1f, + + RequestError = 0x5, } diff --git a/moq-transport/src/message/namespace.rs b/moq-transport/src/message/namespace.rs new file mode 100644 index 00000000..978d1a1d --- /dev/null +++ b/moq-transport/src/message/namespace.rs @@ -0,0 +1,61 @@ +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackNamespace}; + +/// NAMESPACE message (draft-16) +/// +/// Sent by relay to subscriber to announce a namespace matching their SUBSCRIBE_NAMESPACE. +/// This is different from PUBLISH_NAMESPACE which is sent by publisher to relay. +/// +/// Wire format: 0x08 +#[derive(Clone, Debug)] +pub struct Namespace { + /// Request ID (from the SUBSCRIBE_NAMESPACE) + pub id: u64, + /// The namespace being announced + pub track_namespace: TrackNamespace, + /// Optional parameters + pub params: KeyValuePairs, +} + +impl Decode for Namespace { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let track_namespace = TrackNamespace::decode(r)?; + let params = KeyValuePairs::decode(r)?; + + Ok(Self { + id, + track_namespace, + params, + }) + } +} + +impl Encode for Namespace { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.track_namespace.encode(w)?; + self.params.encode(w)?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_namespace_encode_decode() { + let msg = Namespace { + id: 42, + track_namespace: TrackNamespace::from_utf8_path("live/room1"), + params: KeyValuePairs::new(), + }; + + let mut buf = Vec::new(); + msg.encode(&mut buf).unwrap(); + + let decoded = Namespace::decode(&mut buf.as_slice()).unwrap(); + assert_eq!(decoded.id, 42); + assert_eq!(decoded.track_namespace.to_utf8_path(), "live/room1"); + } +} diff --git a/moq-transport/src/message/parameters.rs b/moq-transport/src/message/parameters.rs new file mode 100644 index 00000000..b61d9e37 --- /dev/null +++ b/moq-transport/src/message/parameters.rs @@ -0,0 +1,47 @@ +/// Version-Specific Message Parameter Types +/// Used in SUBSCRIBE, SUBSCRIBE_OK, PUBLISH, FETCH, REQUEST_UPDATE, etc. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[repr(u64)] +pub enum ParameterType { + /// Used in: REQUEST_OK, PUBLISH, PUBLISH_OK, SUBSCRIBE, SUBSCRIBE_OK, REQUEST_UPDATE + DeliveryTimeout = 0x02, + /// Used in: CLIENT_SETUP, SERVER_SETUP, PUBLISH, SUBSCRIBE, REQUEST_UPDATE, + /// SUBSCRIBE_NAMESPACE, PUBLISH_NAMESPACE, TRACK_STATUS, FETCH + AuthorizationToken = 0x03, + /// Used in: PUBLISH, SUBSCRIBE_OK, FETCH_OK, REQUEST_OK + MaxCacheDuration = 0x04, + /// Used in: SUBSCRIBE_OK, PUBLISH, PUBLISH_OK + Expires = 0x08, + /// Used in: SUBSCRIBE_OK, PUBLISH, REQUEST_OK + LargestObject = 0x09, + /// Used in: SUBSCRIBE_OK, PUBLISH + PublisherPriority = 0x0E, + /// Used in: SUBSCRIBE, REQUEST_UPDATE, PUBLISH, PUBLISH_OK, SUBSCRIBE_NAMESPACE + Forward = 0x10, + /// Used in: SUBSCRIBE, FETCH, REQUEST_UPDATE, PUBLISH_OK + SubscriberPriority = 0x20, + /// Used in: SUBSCRIBE, PUBLISH_OK, REQUEST_UPDATE (renamed to SubscriptionLocationFilter per PR #1518) + SubscriptionFilter = 0x21, + /// Used in: SUBSCRIBE, SUBSCRIBE_OK, REQUEST_OK, PUBLISH, PUBLISH_OK, FETCH + GroupOrder = 0x22, + /// Used in: SUBSCRIBE, FETCH - Filter by subgroup ID ranges (PR #1518) + SubgroupFilter = 0x25, + /// Used in: SUBSCRIBE, FETCH - Filter by object ID ranges (PR #1518) + ObjectFilter = 0x26, + /// Used in: SUBSCRIBE, FETCH - Filter by priority ranges (PR #1518) + PriorityFilter = 0x27, + /// Used in: SUBSCRIBE, FETCH - Filter by property value ranges (PR #1518) + PropertyFilter = 0x28, + /// Used in: SUBSCRIBE_NAMESPACE - Track filter for top-N selection (PR #1518) + TrackFilter = 0x29, + /// Used in: PUBLISH, SUBSCRIBE_OK + DynamicGroups = 0x30, + /// Used in: PUBLISH_OK, SUBSCRIBE, REQUEST_UPDATE + NewGroupRequest = 0x32, +} + +impl From for u64 { + fn from(value: ParameterType) -> Self { + value as u64 + } +} diff --git a/moq-transport/src/message/publish.rs b/moq-transport/src/message/publish.rs index feea3639..467b51cd 100644 --- a/moq-transport/src/message/publish.rs +++ b/moq-transport/src/message/publish.rs @@ -1,9 +1,10 @@ -use crate::coding::{ - Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location, TrackNamespace, -}; -use crate::message::GroupOrder; +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackNamespace}; +use crate::data::ExtensionHeaders; /// Sent by publisher to initiate a subscription to a track. +/// +/// Draft-16: Fields like group_order, content_exists, largest_location, forward +/// have been moved to Parameters (Section 9.2.2). #[derive(Clone, Debug, Eq, PartialEq)] pub struct Publish { /// The publish request ID @@ -14,14 +15,11 @@ pub struct Publish { pub track_name: String, // TODO SLG - consider making a FullTrackName base struct (total size limit of 4096) pub track_alias: u64, - pub group_order: GroupOrder, - pub content_exists: bool, - // The largest object available for this track, if content exists. - pub largest_location: Option, - pub forward: bool, - - /// Optional parameters + /// Optional parameters (may contain Forward, GroupOrder, LargestObject, PublisherPriority, etc.) pub params: KeyValuePairs, + + /// Track extensions + pub track_extensions: ExtensionHeaders, } impl Decode for Publish { @@ -32,31 +30,18 @@ impl Decode for Publish { let track_name = String::decode(r)?; let track_alias = u64::decode(r)?; - let group_order = GroupOrder::decode(r)?; - // GroupOrder enum has Publisher in it, but it's not allowed to be used in this - // publish message, so validate it now so we can return a protocol error. - if group_order == GroupOrder::Publisher { - return Err(DecodeError::InvalidGroupOrder); - } - let content_exists = bool::decode(r)?; - let largest_location = match content_exists { - true => Some(Location::decode(r)?), - false => None, - }; - let forward = bool::decode(r)?; - let params = KeyValuePairs::decode(r)?; + // Track Extensions use remaining bytes (no length prefix per draft-16) + let track_extensions = ExtensionHeaders::decode_remaining_bytes(r)?; + Ok(Self { id, track_namespace, track_name, track_alias, - group_order, - content_exists, - largest_location, - forward, params, + track_extensions, }) } } @@ -69,22 +54,8 @@ impl Encode for Publish { self.track_name.encode(w)?; self.track_alias.encode(w)?; - // GroupOrder enum has Publisher in it, but it's not allowed to be used in this - // publish message. - if self.group_order == GroupOrder::Publisher { - return Err(EncodeError::InvalidValue); - } - self.group_order.encode(w)?; - self.content_exists.encode(w)?; - if self.content_exists { - if let Some(largest) = &self.largest_location { - largest.encode(w)?; - } else { - return Err(EncodeError::MissingField("LargestLocation".to_string())); - } - } - self.forward.encode(w)?; self.params.encode(w)?; + self.track_extensions.encode(w)?; Ok(()) } @@ -99,37 +70,16 @@ mod tests { fn encode_decode() { let mut buf = BytesMut::new(); - // One parameter for testing let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(123, vec![0x00, 0x01, 0x02, 0x03]); - // Content exists = true - let msg = Publish { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - track_alias: 212, - group_order: GroupOrder::Ascending, - content_exists: true, - largest_location: Some(Location::new(2, 3)), - forward: true, - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = Publish::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - - // Content exists = false let msg = Publish { id: 12345, track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), track_name: "audiotrack".to_string(), track_alias: 212, - group_order: GroupOrder::Ascending, - content_exists: false, - largest_location: None, - forward: true, params: kvps.clone(), + track_extensions: Default::default(), }; msg.encode(&mut buf).unwrap(); let decoded = Publish::decode(&mut buf).unwrap(); @@ -137,7 +87,7 @@ mod tests { } #[test] - fn encode_missing_fields() { + fn encode_decode_no_params() { let mut buf = BytesMut::new(); let msg = Publish { @@ -145,32 +95,11 @@ mod tests { track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), track_name: "audiotrack".to_string(), track_alias: 212, - group_order: GroupOrder::Ascending, - content_exists: true, - largest_location: None, - forward: true, params: Default::default(), + track_extensions: Default::default(), }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - } - - #[test] - fn encode_bad_group_order() { - let mut buf = BytesMut::new(); - - let msg = Publish { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - track_alias: 212, - group_order: GroupOrder::Publisher, - content_exists: false, - largest_location: None, - forward: true, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::InvalidValue)); + msg.encode(&mut buf).unwrap(); + let decoded = Publish::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); } } diff --git a/moq-transport/src/message/publish_error.rs b/moq-transport/src/message/publish_error.rs deleted file mode 100644 index f8cc02b9..00000000 --- a/moq-transport/src/message/publish_error.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct PublishError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for PublishError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for PublishError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/pubilsh_namespace_done.rs b/moq-transport/src/message/publish_namespace_done.rs similarity index 100% rename from moq-transport/src/message/pubilsh_namespace_done.rs rename to moq-transport/src/message/publish_namespace_done.rs diff --git a/moq-transport/src/message/publish_namespace_error.rs b/moq-transport/src/message/publish_namespace_error.rs deleted file mode 100644 index 8a606621..00000000 --- a/moq-transport/src/message/publish_namespace_error.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an PUBLISH_NAMESPACE. -#[derive(Clone, Debug)] -pub struct PublishNamespaceError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for PublishNamespaceError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for PublishNamespaceError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/publish_namespace_ok.rs b/moq-transport/src/message/publish_namespace_ok.rs deleted file mode 100644 index 9025f03f..00000000 --- a/moq-transport/src/message/publish_namespace_ok.rs +++ /dev/null @@ -1,37 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError}; - -/// Sent by the subscriber to accept a PUBLISH_NAMESPACE. -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct PublishNamespaceOk { - /// The request ID of the PUBLISH_NAMESPACE this message is replying to. - pub id: u64, -} - -impl Decode for PublishNamespaceOk { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - Ok(Self { id }) - } -} - -impl Encode for PublishNamespaceOk { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::BytesMut; - - #[test] - fn encode_decode() { - let mut buf = BytesMut::new(); - - let msg = PublishNamespaceOk { id: 12345 }; - msg.encode(&mut buf).unwrap(); - let decoded = PublishNamespaceOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } -} diff --git a/moq-transport/src/message/publish_ok.rs b/moq-transport/src/message/publish_ok.rs index 5564c7ac..e376c89b 100644 --- a/moq-transport/src/message/publish_ok.rs +++ b/moq-transport/src/message/publish_ok.rs @@ -1,110 +1,30 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; -use crate::message::FilterType; -use crate::message::GroupOrder; +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; -/// Sent by the subscriber to request all future objects for the given track. +/// Sent by the subscriber to acknowledge a PUBLISH message and establish a subscription. /// -/// Objects will use the provided ID instead of the full track name, to save bytes. +/// Draft-16: All subscription properties (forward, subscriber_priority, group_order, +/// filter_type, etc.) are now in Parameters (Section 9.2.2). #[derive(Clone, Debug, Eq, PartialEq)] pub struct PublishOk { /// The request ID of the Publish this message is replying to. pub id: u64, - /// Forward Flag - pub forward: bool, - - /// Subscriber Priority - pub subscriber_priority: u8, - - /// The order the subscription will be delivered in - pub group_order: GroupOrder, - - /// Filter type - pub filter_type: FilterType, - - /// The starting location for this subscription. Only present for "AbsoluteStart" and "AbsoluteRange" filter types. - pub start_location: Option, - /// End group id, inclusive, for the subscription, if applicable. Only present for "AbsoluteRange" filter type. - pub end_group_id: Option, - - /// Optional parameters + /// Parameters (may contain Forward, SubscriberPriority, GroupOrder, SubscriptionFilter, etc.) pub params: KeyValuePairs, } impl Decode for PublishOk { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; - - let forward = bool::decode(r)?; - let subscriber_priority = u8::decode(r)?; - let group_order = GroupOrder::decode(r)?; - - let filter_type = FilterType::decode(r)?; - let start_location: Option; - let end_group_id: Option; - match filter_type { - FilterType::AbsoluteStart => { - start_location = Some(Location::decode(r)?); - end_group_id = None; - } - FilterType::AbsoluteRange => { - start_location = Some(Location::decode(r)?); - end_group_id = Some(u64::decode(r)?); - } - _ => { - start_location = None; - end_group_id = None; - } - } - let params = KeyValuePairs::decode(r)?; - Ok(Self { - id, - forward, - subscriber_priority, - group_order, - filter_type, - start_location, - end_group_id, - params, - }) + Ok(Self { id, params }) } } impl Encode for PublishOk { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; - - self.forward.encode(w)?; - self.subscriber_priority.encode(w)?; - self.group_order.encode(w)?; - - self.filter_type.encode(w)?; - match self.filter_type { - FilterType::AbsoluteStart => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - // Just ignore end_group_id if it happens to be set - } - FilterType::AbsoluteRange => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - if let Some(end) = self.end_group_id { - end.encode(w)?; - } else { - return Err(EncodeError::MissingField("EndGroupId".to_string())); - } - } - _ => {} - } - self.params.encode(w)?; Ok(()) @@ -120,49 +40,11 @@ mod tests { fn encode_decode() { let mut buf = BytesMut::new(); - // One parameter for testing let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(123, vec![0x00, 0x01, 0x02, 0x03]); - // FilterType = NextGroupStart - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::NextGroupStart, - start_location: None, - end_group_id: None, - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = PublishOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - - // FilterType = AbsoluteStart - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteStart, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = PublishOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - - // FilterType = AbsoluteRange let msg = PublishOk { id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: Some(23456), params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); @@ -171,49 +53,15 @@ mod tests { } #[test] - fn encode_missing_fields() { + fn encode_decode_no_params() { let mut buf = BytesMut::new(); - // FilterType = AbsoluteStart - missing start_location - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteStart, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing start_location - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteRange, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing end_group_id let msg = PublishOk { id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, params: Default::default(), }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); + msg.encode(&mut buf).unwrap(); + let decoded = PublishOk::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); } } diff --git a/moq-transport/src/message/publisher.rs b/moq-transport/src/message/publisher.rs index 6cdf0750..61700289 100644 --- a/moq-transport/src/message/publisher.rs +++ b/moq-transport/src/message/publisher.rs @@ -48,14 +48,12 @@ macro_rules! publisher_msgs { publisher_msgs! { PublishNamespace, PublishNamespaceDone, + Namespace, Publish, PublishDone, SubscribeOk, - SubscribeError, TrackStatusOk, - TrackStatusError, FetchOk, - FetchError, - SubscribeNamespaceOk, - SubscribeNamespaceError, + RequestOk, + RequestError, } diff --git a/moq-transport/src/message/request_error.rs b/moq-transport/src/message/request_error.rs new file mode 100644 index 00000000..fce02c6f --- /dev/null +++ b/moq-transport/src/message/request_error.rs @@ -0,0 +1,87 @@ +use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; + +/// REQUEST_ERROR message (draft-16 Section 9.8). +/// +/// Sent in response to any request (SUBSCRIBE, FETCH, PUBLISH, etc.) to indicate failure. +#[derive(Clone, Debug)] +pub struct RequestError { + pub id: u64, + + /// An error code identifying the failure reason. + pub error_code: u64, + + /// Minimum time in milliseconds before the request SHOULD be sent again, plus one. + /// A value of 0 means the request SHOULD NOT be retried. + /// A value of 1 means the request can be retried immediately. + pub retry_interval: u64, + + /// An optional, human-readable reason. + pub reason_phrase: ReasonPhrase, +} + +impl Decode for RequestError { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let error_code = u64::decode(r)?; + let retry_interval = u64::decode(r)?; + let reason_phrase = ReasonPhrase::decode(r)?; + + Ok(Self { + id, + error_code, + retry_interval, + reason_phrase, + }) + } +} + +impl Encode for RequestError { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.error_code.encode(w)?; + self.retry_interval.encode(w)?; + self.reason_phrase.encode(w)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + #[test] + fn encode_decode() { + let mut buf = BytesMut::new(); + + let msg = RequestError { + id: 42, + error_code: 0x1, + retry_interval: 5000, + reason_phrase: ReasonPhrase("unauthorized".to_string()), + }; + msg.encode(&mut buf).unwrap(); + let decoded = RequestError::decode(&mut buf).unwrap(); + assert_eq!(decoded.id, msg.id); + assert_eq!(decoded.error_code, msg.error_code); + assert_eq!(decoded.retry_interval, msg.retry_interval); + } + + #[test] + fn encode_decode_no_retry() { + let mut buf = BytesMut::new(); + + let msg = RequestError { + id: 10, + error_code: 0x0, + retry_interval: 0, + reason_phrase: ReasonPhrase("internal error".to_string()), + }; + msg.encode(&mut buf).unwrap(); + let decoded = RequestError::decode(&mut buf).unwrap(); + assert_eq!(decoded.id, msg.id); + assert_eq!(decoded.error_code, msg.error_code); + assert_eq!(decoded.retry_interval, 0); + } +} diff --git a/moq-transport/src/message/request_ok.rs b/moq-transport/src/message/request_ok.rs new file mode 100644 index 00000000..9ceb8879 --- /dev/null +++ b/moq-transport/src/message/request_ok.rs @@ -0,0 +1,45 @@ +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; + +/// Reqeust Ok +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct RequestOk { + /// The SubscribeNamespace/PublishNamespace request ID this message is replying to. + pub id: u64, + + /// Optional parameters + pub params: KeyValuePairs, +} + +impl Decode for RequestOk { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let params = KeyValuePairs::decode(r)?; + Ok(Self { id, params }) + } +} + +impl Encode for RequestOk { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.params.encode(w) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + #[test] + fn encode_decode() { + let mut buf = BytesMut::new(); + + let msg = RequestOk { + id: 12345, + params: KeyValuePairs::new(), + }; + msg.encode(&mut buf).unwrap(); + let decoded = RequestOk::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + } +} diff --git a/moq-transport/src/message/subscribe.rs b/moq-transport/src/message/subscribe.rs index 828ccb92..e2a3d0bd 100644 --- a/moq-transport/src/message/subscribe.rs +++ b/moq-transport/src/message/subscribe.rs @@ -1,8 +1,4 @@ -use crate::coding::{ - Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location, TrackNamespace, -}; -use crate::message::FilterType; -use crate::message::GroupOrder; +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackNamespace}; /// Sent by the subscriber to request all future objects for the given track. /// @@ -16,22 +12,9 @@ pub struct Subscribe { pub track_namespace: TrackNamespace, pub track_name: String, // TODO SLG - consider making a FullTrackName base struct (total size limit of 4096) - /// Subscriber Priority - pub subscriber_priority: u8, - pub group_order: GroupOrder, - - /// Forward Flag - pub forward: bool, - - /// Filter type - pub filter_type: FilterType, - - /// The starting location for this subscription. Only present for "AbsoluteStart" and "AbsoluteRange" filter types. - pub start_location: Option, - /// End group id, inclusive, for the subscription, if applicable. Only present for "AbsoluteRange" filter type. - pub end_group_id: Option, - /// Optional parameters + /// NOTE(itzmanish): since the forward and other fields are moved to parameters + /// we need to validate them on publisher logic pub params: KeyValuePairs, } @@ -42,41 +25,12 @@ impl Decode for Subscribe { let track_namespace = TrackNamespace::decode(r)?; let track_name = String::decode(r)?; - let subscriber_priority = u8::decode(r)?; - let group_order = GroupOrder::decode(r)?; - - let forward = bool::decode(r)?; - - let filter_type = FilterType::decode(r)?; - let start_location: Option; - let end_group_id: Option; - match filter_type { - FilterType::AbsoluteStart => { - start_location = Some(Location::decode(r)?); - end_group_id = None; - } - FilterType::AbsoluteRange => { - start_location = Some(Location::decode(r)?); - end_group_id = Some(u64::decode(r)?); - } - _ => { - start_location = None; - end_group_id = None; - } - } - let params = KeyValuePairs::decode(r)?; Ok(Self { id, track_namespace, track_name, - subscriber_priority, - group_order, - forward, - filter_type, - start_location, - end_group_id, params, }) } @@ -88,37 +42,6 @@ impl Encode for Subscribe { self.track_namespace.encode(w)?; self.track_name.encode(w)?; - - self.subscriber_priority.encode(w)?; - self.group_order.encode(w)?; - - self.forward.encode(w)?; - - self.filter_type.encode(w)?; - match self.filter_type { - FilterType::AbsoluteStart => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - // Just ignore end_group_id if it happens to be set - } - FilterType::AbsoluteRange => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - if let Some(end) = self.end_group_id { - end.encode(w)?; - } else { - return Err(EncodeError::MissingField("EndGroupId".to_string())); - } - } - _ => {} - } - self.params.encode(w)?; Ok(()) @@ -143,12 +66,6 @@ mod tests { id: 12345, track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::NextGroupStart, - start_location: None, - end_group_id: None, params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); @@ -160,12 +77,6 @@ mod tests { id: 12345, track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteStart, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); @@ -177,69 +88,10 @@ mod tests { id: 12345, track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: Some(23456), params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); let decoded = Subscribe::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); } - - #[test] - fn encode_missing_fields() { - let mut buf = BytesMut::new(); - - // FilterType = AbsoluteStart - missing start_location - let msg = Subscribe { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteStart, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing start_location - let msg = Subscribe { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteRange, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing end_group_id - let msg = Subscribe { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - } } diff --git a/moq-transport/src/message/subscribe_error.rs b/moq-transport/src/message/subscribe_error.rs deleted file mode 100644 index 7481a4bf..00000000 --- a/moq-transport/src/message/subscribe_error.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct SubscribeError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for SubscribeError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for SubscribeError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/subscribe_namespace.rs b/moq-transport/src/message/subscribe_namespace.rs index ba292d69..92662036 100644 --- a/moq-transport/src/message/subscribe_namespace.rs +++ b/moq-transport/src/message/subscribe_namespace.rs @@ -9,19 +9,36 @@ pub struct SubscribeNamespace { /// The track namespace prefix pub track_namespace_prefix: TrackNamespace, + /// The Forward value that new subscriptions resulting from this SUBSCRIBE_NAMESPACE will have + pub forward: u8, + /// Optional parameters pub params: KeyValuePairs, } +impl SubscribeNamespace { + /// Creates a new SubscribeNamespace message. + pub fn new(id: u64, track_namespace_prefix: TrackNamespace, forward: u8) -> Self { + Self { + id, + track_namespace_prefix, + forward, + params: KeyValuePairs::new(), + } + } +} + impl Decode for SubscribeNamespace { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; let track_namespace_prefix = TrackNamespace::decode(r)?; + let forward = u8::decode(r)?; let params = KeyValuePairs::decode(r)?; Ok(Self { id, track_namespace_prefix, + forward, params, }) } @@ -31,6 +48,7 @@ impl Encode for SubscribeNamespace { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; self.track_namespace_prefix.encode(w)?; + self.forward.encode(w)?; self.params.encode(w)?; Ok(()) @@ -52,11 +70,14 @@ mod tests { let msg = SubscribeNamespace { id: 12345, + forward: 0, track_namespace_prefix: TrackNamespace::from_utf8_path("path/prefix"), params: kvps, }; msg.encode(&mut buf).unwrap(); let decoded = SubscribeNamespace::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); + assert_eq!(decoded.id, msg.id); + assert_eq!(decoded.forward, msg.forward); + assert_eq!(decoded.track_namespace_prefix, msg.track_namespace_prefix); } } diff --git a/moq-transport/src/message/subscribe_namespace_error.rs b/moq-transport/src/message/subscribe_namespace_error.rs deleted file mode 100644 index a5d99d0d..00000000 --- a/moq-transport/src/message/subscribe_namespace_error.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct SubscribeNamespaceError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for SubscribeNamespaceError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for SubscribeNamespaceError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/subscribe_namespace_ok.rs b/moq-transport/src/message/subscribe_namespace_ok.rs deleted file mode 100644 index 2e2a968d..00000000 --- a/moq-transport/src/message/subscribe_namespace_ok.rs +++ /dev/null @@ -1,37 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError}; - -/// Subscribe Namespace Ok -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct SubscribeNamespaceOk { - /// The SubscribeNamespace request ID this message is replying to. - pub id: u64, -} - -impl Decode for SubscribeNamespaceOk { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - Ok(Self { id }) - } -} - -impl Encode for SubscribeNamespaceOk { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::BytesMut; - - #[test] - fn encode_decode() { - let mut buf = BytesMut::new(); - - let msg = SubscribeNamespaceOk { id: 12345 }; - msg.encode(&mut buf).unwrap(); - let decoded = SubscribeNamespaceOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } -} diff --git a/moq-transport/src/message/subscribe_ok.rs b/moq-transport/src/message/subscribe_ok.rs index 97bb3aa7..c87f952b 100644 --- a/moq-transport/src/message/subscribe_ok.rs +++ b/moq-transport/src/message/subscribe_ok.rs @@ -1,5 +1,4 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; -use crate::message::GroupOrder; +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackExtensions}; /// Sent by the publisher to accept a Subscribe. #[derive(Clone, Debug, Eq, PartialEq)] @@ -10,42 +9,26 @@ pub struct SubscribeOk { /// The identifier used for this track in Subgroups or Datagrams. pub track_alias: u64, - /// The time in milliseconds after which the subscription is not longer valid. - pub expires: u64, - - /// Order groups will be delivered in - pub group_order: GroupOrder, - - /// If content_exists, then largest_location is the location of the largest - /// object available for this track - pub content_exists: bool, - pub largest_location: Option, // Only provided if content_exists is 1/true - - /// Subscribe Parameters + /// Subscribe Parameters (has count prefix per spec) pub params: KeyValuePairs, + + /// Track extensions (NO prefix per draft-16 Section 9.10 - reads until end of message) + pub track_extensions: TrackExtensions, } impl Decode for SubscribeOk { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; let track_alias = u64::decode(r)?; - let expires = u64::decode(r)?; - let group_order = GroupOrder::decode(r)?; - let content_exists = bool::decode(r)?; - let largest_location = match content_exists { - true => Some(Location::decode(r)?), - false => None, - }; let params = KeyValuePairs::decode(r)?; + // Track extensions have NO prefix - read until end of message + let track_extensions = TrackExtensions::decode(r)?; Ok(Self { id, track_alias, - expires, - group_order, - content_exists, - largest_location, params, + track_extensions, }) } } @@ -54,17 +37,8 @@ impl Encode for SubscribeOk { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; self.track_alias.encode(w)?; - self.expires.encode(w)?; - self.group_order.encode(w)?; - self.content_exists.encode(w)?; - if self.content_exists { - if let Some(largest) = &self.largest_location { - largest.encode(w)?; - } else { - return Err(EncodeError::MissingField("LargestLocation".to_string())); - } - } self.params.encode(w)?; + self.track_extensions.encode(w)?; Ok(()) } @@ -83,14 +57,15 @@ mod tests { let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(123, vec![0x00, 0x01, 0x02, 0x03]); + // Track extensions (no prefix) + let mut ext = TrackExtensions::new(); + ext.set_intvalue(2, 42); + let msg = SubscribeOk { id: 12345, track_alias: 100, - expires: 3600, - group_order: GroupOrder::Publisher, - content_exists: true, - largest_location: Some(Location::new(2, 3)), - params: kvps.clone(), + params: kvps, + track_extensions: ext, }; msg.encode(&mut buf).unwrap(); let decoded = SubscribeOk::decode(&mut buf).unwrap(); @@ -98,19 +73,22 @@ mod tests { } #[test] - fn encode_missing_fields() { + fn encode_decode_empty_extensions() { let mut buf = BytesMut::new(); let msg = SubscribeOk { - id: 12345, - track_alias: 100, - expires: 3600, - group_order: GroupOrder::Publisher, - content_exists: true, - largest_location: None, - params: Default::default(), + id: 0, + track_alias: 0, + params: KeyValuePairs::new(), + track_extensions: TrackExtensions::new(), }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); + msg.encode(&mut buf).unwrap(); + // Expected: id=0 (1 byte), track_alias=0 (1 byte), params_count=0 (1 byte), NO track_extensions bytes + assert_eq!(buf.to_vec(), vec![0x00, 0x00, 0x00]); + let decoded = SubscribeOk::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); } + + // Note: encode_missing_fields test removed — content_exists was removed + // from the struct in draft-16; no fields to validate at encode time. } diff --git a/moq-transport/src/message/subscribe_update.rs b/moq-transport/src/message/subscribe_update.rs index 3bf20e23..895378d9 100644 --- a/moq-transport/src/message/subscribe_update.rs +++ b/moq-transport/src/message/subscribe_update.rs @@ -1,53 +1,31 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; -/// Sent by the subscriber to request all future objects for the given track. +/// REQUEST_UPDATE message (draft-16 Section 9.11). /// -/// Objects will use the provided ID instead of the full track name, to save bytes. +/// Sent to modify an existing request (SUBSCRIBE, PUBLISH, FETCH, etc.). +/// Parameters previously set that are not present in the update remain unchanged. #[derive(Clone, Debug, Eq, PartialEq)] pub struct SubscribeUpdate { - /// The request ID of this request + /// The request ID of this REQUEST_UPDATE pub id: u64, - /// The request ID of the SUBSCRIBE this message is updating. - pub subscription_request_id: u64, + /// The request ID of the existing request this message is updating. + pub existing_request_id: u64, - /// The starting location - pub start_location: Location, - /// The end Group ID, plus 1. A value of 0 means the subscription is open-ended. - pub end_group_id: u64, - - /// Subscriber Priority - pub subscriber_priority: u8, - - /// Forward Flag - pub forward: bool, - - /// Optional parameters + /// Parameters to update (draft-16 Section 9.2.2). + /// Parameters not present remain unchanged from the original request. pub params: KeyValuePairs, } impl Decode for SubscribeUpdate { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; - - let subscription_request_id = u64::decode(r)?; - - let start_location = Location::decode(r)?; - let end_group_id = u64::decode(r)?; - - let subscriber_priority = u8::decode(r)?; - - let forward = bool::decode(r)?; - + let existing_request_id = u64::decode(r)?; let params = KeyValuePairs::decode(r)?; Ok(Self { id, - subscription_request_id, - start_location, - end_group_id, - subscriber_priority, - forward, + existing_request_id, params, }) } @@ -56,16 +34,7 @@ impl Decode for SubscribeUpdate { impl Encode for SubscribeUpdate { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; - - self.subscription_request_id.encode(w)?; - - self.start_location.encode(w)?; - self.end_group_id.encode(w)?; - - self.subscriber_priority.encode(w)?; - - self.forward.encode(w)?; - + self.existing_request_id.encode(w)?; self.params.encode(w)?; Ok(()) @@ -81,21 +50,30 @@ mod tests { fn encode_decode() { let mut buf = BytesMut::new(); - // One parameter for testing let mut kvps = KeyValuePairs::new(); kvps.set_intvalue(124, 456); let msg = SubscribeUpdate { id: 1000, - subscription_request_id: 924, - start_location: Location::new(1, 1), - end_group_id: 100000, - subscriber_priority: 127, - forward: true, + existing_request_id: 924, params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); let decoded = SubscribeUpdate::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); } + + #[test] + fn encode_decode_empty_params() { + let mut buf = BytesMut::new(); + + let msg = SubscribeUpdate { + id: 5, + existing_request_id: 3, + params: KeyValuePairs::new(), + }; + msg.encode(&mut buf).unwrap(); + let decoded = SubscribeUpdate::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + } } diff --git a/moq-transport/src/message/subscriber.rs b/moq-transport/src/message/subscriber.rs index 0a11fb9e..3c433149 100644 --- a/moq-transport/src/message/subscriber.rs +++ b/moq-transport/src/message/subscriber.rs @@ -53,10 +53,8 @@ subscriber_msgs! { FetchCancel, TrackStatus, SubscribeNamespace, - UnsubscribeNamespace, PublishNamespaceCancel, - PublishNamespaceOk, - PublishNamespaceError, + RequestOk, PublishOk, - PublishError, + RequestError, } diff --git a/moq-transport/src/message/track_status_error.rs b/moq-transport/src/message/track_status_error.rs deleted file mode 100644 index 7b015ea3..00000000 --- a/moq-transport/src/message/track_status_error.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct TrackStatusError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for TrackStatusError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for TrackStatusError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/unsubscribe_namespace.rs b/moq-transport/src/message/unsubscribe_namespace.rs deleted file mode 100644 index de257378..00000000 --- a/moq-transport/src/message/unsubscribe_namespace.rs +++ /dev/null @@ -1,42 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, TrackNamespace}; - -/// Unsubscribe Namespace -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct UnsubscribeNamespace { - // Echo back the track namespace prefix from subscribe namespace - pub track_namespace_prefix: TrackNamespace, -} - -impl Decode for UnsubscribeNamespace { - fn decode(r: &mut R) -> Result { - let track_namespace_prefix = TrackNamespace::decode(r)?; - Ok(Self { - track_namespace_prefix, - }) - } -} - -impl Encode for UnsubscribeNamespace { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.track_namespace_prefix.encode(w)?; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::BytesMut; - - #[test] - fn encode_decode() { - let mut buf = BytesMut::new(); - - let msg = UnsubscribeNamespace { - track_namespace_prefix: TrackNamespace::from_utf8_path("test/path/to/resource"), - }; - msg.encode(&mut buf).unwrap(); - let decoded = UnsubscribeNamespace::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } -} diff --git a/moq-transport/src/mlog/events.rs b/moq-transport/src/mlog/events.rs index 53c513f0..f074b747 100644 --- a/moq-transport/src/mlog/events.rs +++ b/moq-transport/src/mlog/events.rs @@ -2,10 +2,10 @@ // - SubscribeUpdate (parsed/created) // - PublishNamespaceDone (parsed/created) // - PublishNamespaceCancel (parsed/created) -// - TrackStatus, TrackStatusOk, TrackStatusError (parsed/created) -// - SubscribeNamespace, SubscribeNamespaceOk, SubscribeNamespaceError, UnsubscribeNamespace (parsed/created) -// - Fetch, FetchOk, FetchError, FetchCancel (parsed/created) -// - Publish, PublishOk, PublishError, PublishDone (parsed/created) +// - TrackStatus, TrackStatusOk (parsed/created) +// - SubscribeNamespace (parsed/created) +// - Fetch, FetchOk, FetchCancel (parsed/created) +// - Publish, PublishOk, PublishDone (parsed/created) // - MaxRequestId (parsed/created) // - RequestsBlocked (parsed/created) // @@ -207,7 +207,6 @@ fn create_control_message_event( /// Create a control_message_parsed event for CLIENT_SETUP pub fn client_setup_parsed(time: f64, stream_id: u64, msg: &setup::Client) -> Event { - let versions: Vec = msg.versions.0.iter().map(|v| format!("{:?}", v)).collect(); create_control_message_event( time, stream_id, @@ -215,8 +214,6 @@ pub fn client_setup_parsed(time: f64, stream_id: u64, msg: &setup::Client) -> Ev "client_setup", json!( { - "number_of_supported_versions": msg.versions.0.len(), - "supported_versions": versions, "parameters": key_value_pairs_to_vec(&msg.params.0), }), ) @@ -231,7 +228,6 @@ pub fn server_setup_created(time: f64, stream_id: u64, msg: &setup::Server) -> E "server_setup", json!( { - "selected_version": format!("{:?}", msg.version), "parameters": key_value_pairs_to_vec(&msg.params.0), }), ) @@ -239,25 +235,12 @@ pub fn server_setup_created(time: f64, stream_id: u64, msg: &setup::Server) -> E /// Helper to convert SUBSCRIBE message to JSON fn subscribe_to_json(msg: &message::Subscribe) -> JsonValue { - let mut json = json!({ + let json = json!({ "subscribe_id": msg.id, "track_namespace": msg.track_namespace.to_string(), "track_name": &msg.track_name, - "subscriber_priority": msg.subscriber_priority, - "group_order": format!("{:?}", msg.group_order), - "filter_type": format!("{:?}", msg.filter_type), "parameters": key_value_pairs_to_vec(&msg.params.0), }); - - // Add optional fields based on filter type - if let Some(start_loc) = &msg.start_location { - json["start_group"] = json!(start_loc.group_id); - json["start_object"] = json!(start_loc.object_id); - } - if let Some(end_group) = msg.end_group_id { - json["end_group"] = json!(end_group); - } - json } @@ -273,23 +256,11 @@ pub fn subscribe_created(time: f64, stream_id: u64, msg: &message::Subscribe) -> /// Helper to convert SUBSCRIBE_OK message to JSON fn subscribe_ok_to_json(msg: &message::SubscribeOk) -> JsonValue { - let mut json = json!({ + let json = json!({ "subscribe_id": msg.id, "track_alias": msg.track_alias, - "expires": msg.expires, - "group_order": format!("{:?}", msg.group_order), - "content_exists": msg.content_exists, "parameters": key_value_pairs_to_vec(&msg.params.0), }); - - // Add optional largest_location fields if content exists - if msg.content_exists { - if let Some(largest) = &msg.largest_location { - json["largest_group_id"] = json!(largest.group_id); - json["largest_object_id"] = json!(largest.object_id); - } - } - json } @@ -316,33 +287,34 @@ pub fn subscribe_ok_created(time: f64, stream_id: u64, msg: &message::SubscribeO } /// Helper to convert SUBSCRIBE_ERROR message to JSON -fn subscribe_error_to_json(msg: &message::SubscribeError) -> JsonValue { +fn request_error_to_json(msg: &message::RequestError) -> JsonValue { json!({ - "subscribe_id": msg.id, + "request_id": msg.id, "error_code": msg.error_code, + "retry_interval": msg.retry_interval, "reason_phrase": &msg.reason_phrase.0, }) } /// Create a control_message_parsed event for SUBSCRIBE_ERROR -pub fn subscribe_error_parsed(time: f64, stream_id: u64, msg: &message::SubscribeError) -> Event { +pub fn request_error_parsed(time: f64, stream_id: u64, msg: &message::RequestError) -> Event { create_control_message_event( time, stream_id, true, - "subscribe_error", - subscribe_error_to_json(msg), + "request_error", + request_error_to_json(msg), ) } /// Create a control_message_created event for SUBSCRIBE_ERROR -pub fn subscribe_error_created(time: f64, stream_id: u64, msg: &message::SubscribeError) -> Event { +pub fn reqeust_error_created(time: f64, stream_id: u64, msg: &message::RequestError) -> Event { create_control_message_event( time, stream_id, false, - "subscribe_error", - subscribe_error_to_json(msg), + "request_error", + request_error_to_json(msg), ) } @@ -386,78 +358,100 @@ pub fn publish_namespace_created( } /// Helper to convert PUBLISH_NAMESPACE_OK message to JSON -fn publish_namespace_ok_to_json(msg: &message::PublishNamespaceOk) -> JsonValue { +fn request_ok_to_json(msg: &message::RequestOk) -> JsonValue { json!({ "request_id": msg.id, }) } -/// Create a control_message_parsed event for PUBLISH_NAMESPACE_OK (was ANNOUNCE_OK) -pub fn publish_namespace_ok_parsed( - time: f64, - stream_id: u64, - msg: &message::PublishNamespaceOk, -) -> Event { +/// Create a control_message_parsed event for REQUEST_OK +pub fn request_ok_parsed(time: f64, stream_id: u64, msg: &message::RequestOk) -> Event { + create_control_message_event(time, stream_id, true, "request_ok", request_ok_to_json(msg)) +} + +/// Create a control_message_created event for Reqeust OK +pub fn reqeust_ok_created(time: f64, stream_id: u64, msg: &message::RequestOk) -> Event { create_control_message_event( time, stream_id, - true, - "publish_namespace_ok", - publish_namespace_ok_to_json(msg), + false, + "request_ok", + request_ok_to_json(msg), ) } -/// Create a control_message_created event for PUBLISH_NAMESPACE_OK -pub fn publish_namespace_ok_created( - time: f64, - stream_id: u64, - msg: &message::PublishNamespaceOk, -) -> Event { +fn publish_to_json(msg: &message::Publish) -> JsonValue { + json!({ + "publish_id": msg.id, + "track_namespace": msg.track_namespace.to_string(), + "track_name": &msg.track_name, + "track_alias": msg.track_alias, + "parameters": key_value_pairs_to_vec(&msg.params.0), + }) +} + +/// Create a control_message_parsed event for PUBLISH +pub fn publish_parsed(time: f64, stream_id: u64, msg: &message::Publish) -> Event { + create_control_message_event(time, stream_id, true, "publish", publish_to_json(msg)) +} + +/// Create a control_message_created event for PUBLISH +pub fn publish_created(time: f64, stream_id: u64, msg: &message::Publish) -> Event { + create_control_message_event(time, stream_id, false, "publish", publish_to_json(msg)) +} + +fn publish_ok_to_json(msg: &message::PublishOk) -> JsonValue { + json!({ + "publish_id": msg.id, + "parameters": key_value_pairs_to_vec(&msg.params.0), + }) +} + +/// Create a control_message_parsed event for PUBLISH_OK +pub fn publish_ok_parsed(time: f64, stream_id: u64, msg: &message::PublishOk) -> Event { + create_control_message_event(time, stream_id, true, "publish_ok", publish_ok_to_json(msg)) +} + +/// Create a control_message_created event for PUBLISH_OK +pub fn publish_ok_created(time: f64, stream_id: u64, msg: &message::PublishOk) -> Event { create_control_message_event( time, stream_id, false, - "publish_namespace_ok", - publish_namespace_ok_to_json(msg), + "publish_ok", + publish_ok_to_json(msg), ) } -/// Helper to convert PUBLISH_NAMESPACE_ERROR message to JSON -fn publish_namespace_error_to_json(msg: &message::PublishNamespaceError) -> JsonValue { +/// Helper to convert PUBLISH_DONE message to JSON +fn publish_done_to_json(msg: &message::PublishDone) -> JsonValue { json!({ - "request_id": msg.id, - "error_code": msg.error_code, - "reason_phrase": &msg.reason_phrase.0, + "publish_id": msg.id, + "status_code": msg.status_code, + "stream_count": msg.stream_count, + "reason": &msg.reason.0, }) } -/// Create a control_message_parsed event for PUBLISH_NAMESPACE_ERROR (was ANNOUNCE_ERROR) -pub fn publish_namespace_error_parsed( - time: f64, - stream_id: u64, - msg: &message::PublishNamespaceError, -) -> Event { +/// Create a control_message_parsed event for PUBLISH_DONE +pub fn publish_done_parsed(time: f64, stream_id: u64, msg: &message::PublishDone) -> Event { create_control_message_event( time, stream_id, true, - "publish_namespace_error", - publish_namespace_error_to_json(msg), + "publish_done", + publish_done_to_json(msg), ) } -/// Create a control_message_created event for PUBLISH_NAMESPACE_ERROR -pub fn publish_namespace_error_created( - time: f64, - stream_id: u64, - msg: &message::PublishNamespaceError, -) -> Event { +/// Create a control_message_created event for PUBLISH_DONE +pub fn publish_done_created(time: f64, stream_id: u64, msg: &message::PublishDone) -> Event { create_control_message_event( time, stream_id, false, - "publish_namespace_error", - publish_namespace_error_to_json(msg), + "publish_done", + publish_done_to_json(msg), ) } diff --git a/moq-transport/src/serve/datagram.rs b/moq-transport/src/serve/datagram.rs index f7ff2697..1eb07e73 100644 --- a/moq-transport/src/serve/datagram.rs +++ b/moq-transport/src/serve/datagram.rs @@ -1,118 +1,106 @@ use std::{fmt, sync::Arc}; -use crate::watch::State; +use tokio::sync::broadcast; use super::{ServeError, Track}; +const DATAGRAM_CHANNEL_SIZE: usize = 4096; + pub struct Datagrams { pub track: Arc, } impl Datagrams { pub fn produce(self) -> (DatagramsWriter, DatagramsReader) { - let (writer, reader) = State::default().split(); + let (tx, rx) = broadcast::channel(DATAGRAM_CHANNEL_SIZE); - let writer = DatagramsWriter::new(writer, self.track.clone()); - let reader = DatagramsReader::new(reader, self.track); + // Keep a reference to the sender in the reader so clones get fresh receivers + let tx_for_reader = tx.clone(); + let writer = DatagramsWriter::new(tx, self.track.clone()); + let reader = DatagramsReader::new(rx, tx_for_reader, self.track); (writer, reader) } } -struct DatagramsState { - // The latest datagram - latest: Option, - - // Increased each time datagram changes. - epoch: u64, - - // Set when the writer or all readers are dropped. - closed: Result<(), ServeError>, -} - -impl Default for DatagramsState { - fn default() -> Self { - Self { - latest: None, - epoch: 0, - closed: Ok(()), - } - } -} - pub struct DatagramsWriter { - state: State, + tx: broadcast::Sender, pub track: Arc, } impl DatagramsWriter { - fn new(state: State, track: Arc) -> Self { - Self { state, track } + fn new(tx: broadcast::Sender, track: Arc) -> Self { + Self { tx, track } } pub fn write(&mut self, datagram: Datagram) -> Result<(), ServeError> { - let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; - - state.latest = Some(datagram); - state.epoch += 1; - + // Ignore send errors (no receivers) - datagrams are fire-and-forget + let _ = self.tx.send(datagram); Ok(()) } - pub fn close(self, err: ServeError) -> Result<(), ServeError> { - let state = self.state.lock(); - state.closed.clone()?; - - let mut state = state.into_mut().ok_or(ServeError::Cancel)?; - state.closed = Err(err); - + pub fn close(self, _err: ServeError) -> Result<(), ServeError> { + // Channel closes when tx is dropped Ok(()) } } -#[derive(Clone)] pub struct DatagramsReader { - state: State, + rx: broadcast::Receiver, + tx: broadcast::Sender, pub track: Arc, + latest: Option<(u64, u64)>, +} - epoch: u64, +impl Clone for DatagramsReader { + fn clone(&self) -> Self { + // Subscribe to get a NEW receiver that will get all FUTURE datagrams + // This is correct for relay: each subscriber gets datagrams from now on + Self { + rx: self.tx.subscribe(), + tx: self.tx.clone(), + track: self.track.clone(), + latest: self.latest, + } + } } impl DatagramsReader { - fn new(state: State, track: Arc) -> Self { + fn new(rx: broadcast::Receiver, tx: broadcast::Sender, track: Arc) -> Self { Self { - state, + rx, + tx, track, - epoch: 0, + latest: None, } } pub async fn read(&mut self) -> Result, ServeError> { loop { - { - let state = self.state.lock(); - if self.epoch < state.epoch { - self.epoch = state.epoch; - return Ok(state.latest.clone()); + match self.rx.recv().await { + Ok(datagram) => { + self.latest = Some((datagram.group_id, datagram.object_id)); + return Ok(Some(datagram)); } - - state.closed.clone()?; - match state.modified() { - Some(notify) => notify, - None => return Ok(None), // No more updates will come + Err(broadcast::error::RecvError::Lagged(n)) => { + log::warn!("[DATAGRAMS] reader lagged by {} datagrams", n); + // Continue reading - we'll get the next available datagram + } + Err(broadcast::error::RecvError::Closed) => { + return Ok(None); // Channel closed } } - .await; } } - // Returns the largest group/sequence pub fn latest(&self) -> Option<(u64, u64)> { - let state = self.state.lock(); - state - .latest - .as_ref() - .map(|datagram| (datagram.group_id, datagram.object_id)) + self.latest + } + + pub fn is_closed(&self) -> bool { + // Check if sender is gone (receiver_count would be 0 or send would fail) + // But we can't easily check this, so return false (conservative) + false } } @@ -126,6 +114,9 @@ pub struct Datagram { // Extension headers (for draft-14 compliance, particularly immutable extensions) pub extension_headers: crate::data::ExtensionHeaders, + + // Object status (e.g., EndOfGroup) + pub status: Option, } impl fmt::Debug for Datagram { @@ -136,6 +127,7 @@ impl fmt::Debug for Datagram { .field("priority", &self.priority) .field("payload", &self.payload.len()) .field("extension_headers", &self.extension_headers) + .field("status", &self.status) .finish() } } diff --git a/moq-transport/src/serve/error.rs b/moq-transport/src/serve/error.rs index bb0995b5..57666d3a 100644 --- a/moq-transport/src/serve/error.rs +++ b/moq-transport/src/serve/error.rs @@ -36,6 +36,10 @@ pub enum ServeError { #[error("not implemented: {0} [error:{1}]")] NotImplementedWithId(String, uuid::Uuid), + + /// Relay already has an active SUBSCRIBE path, not interested in PUBLISH + #[error("uninterested")] + Uninterested, } impl ServeError { @@ -60,6 +64,8 @@ impl ServeError { Self::NotImplemented(_) | Self::NotImplementedWithId(_, _) => 0x3, // INTERNAL_ERROR (0x0) - per-request error registries use 0x0 Self::Internal(_) | Self::InternalWithId(_, _) => 0x0, + // UNINTERESTED (0x1) - relay already has data path via SUBSCRIBE + Self::Uninterested => 0x1, } } diff --git a/moq-transport/src/serve/stream.rs b/moq-transport/src/serve/stream.rs index e56c1405..020e2aba 100644 --- a/moq-transport/src/serve/stream.rs +++ b/moq-transport/src/serve/stream.rs @@ -188,6 +188,11 @@ impl StreamReader { ) }) } + + pub fn is_closed(&self) -> bool { + let state = self.state.lock(); + state.closed.is_err() || state.modified().is_none() + } } impl Deref for StreamReader { diff --git a/moq-transport/src/serve/subgroup.rs b/moq-transport/src/serve/subgroup.rs index 2d0fc0c0..e47048b6 100644 --- a/moq-transport/src/serve/subgroup.rs +++ b/moq-transport/src/serve/subgroup.rs @@ -95,6 +95,7 @@ impl SubgroupsWriter { group_id, subgroup_id, priority, + header_type: None, }) } @@ -105,6 +106,7 @@ impl SubgroupsWriter { group_id: subgroup.group_id, subgroup_id: subgroup.subgroup_id, priority: subgroup.priority, + header_type: subgroup.header_type, }; let (writer, reader) = subgroup.produce(); @@ -114,8 +116,17 @@ impl SubgroupsWriter { // TODO: Check this logic again if writer.group_id.cmp(&latest.group_id) == cmp::Ordering::Equal { match writer.subgroup_id.cmp(&latest.subgroup_id) { - cmp::Ordering::Less => return Ok(writer), // dropped immediately, lul - cmp::Ordering::Equal => return Err(ServeError::Duplicate), + cmp::Ordering::Less => return Ok(writer), // dropped immediately + cmp::Ordering::Equal => { + // Duplicate subgroup - silently drop instead of erroring + // This can happen with SubgroupZeroIdEndOfGroup streams + log::warn!( + "duplicate subgroup: group_id={}, subgroup_id={} - dropping", + writer.group_id, + writer.subgroup_id + ); + return Ok(writer); // writer dropped, data lost but relay continues + } cmp::Ordering::Greater => state.latest_subgroup_reader = Some(reader), } } else if writer.group_id.cmp(&latest.group_id) == cmp::Ordering::Greater { @@ -199,6 +210,12 @@ impl SubgroupsReader { .as_ref() .map(|group| (group.group_id, group.latest())) } + + /// Check if the subgroups writer has been closed or dropped. + pub fn is_closed(&self) -> bool { + let state = self.state.lock(); + state.closed.is_err() || state.modified().is_none() + } } impl Deref for SubgroupsReader { @@ -222,6 +239,9 @@ pub struct Subgroup { // The priority of the group within the track. pub priority: u8, + + // The stream header type used for this subgroup (preserved from incoming stream) + pub header_type: Option, } /// Static information about the group @@ -239,6 +259,9 @@ pub struct SubgroupInfo { // The priority of the group within the track. pub priority: u8, + + // The stream header type used for this subgroup (preserved from incoming stream) + pub header_type: Option, } impl SubgroupInfo { @@ -313,11 +336,21 @@ impl SubgroupWriter { &mut self, size: usize, extension_headers: Option, + ) -> Result { + self.create_with_status(size, extension_headers, ObjectStatus::NormalObject) + } + + /// Write an object with a specific status (e.g., EndOfGroup). + pub fn create_with_status( + &mut self, + size: usize, + extension_headers: Option, + status: ObjectStatus, ) -> Result { let (writer, reader) = SubgroupObject { group: self.info.clone(), object_id: self.next_object_id, - status: ObjectStatus::NormalObject, + status, size, extension_headers: extension_headers.unwrap_or_default(), } @@ -331,6 +364,16 @@ impl SubgroupWriter { Ok(writer) } + /// Write an EndOfGroup marker object to signal the end of this subgroup. + /// This should be called when the group is complete. + pub fn end_of_group(&mut self) -> Result<(), ServeError> { + // Create an object with size=0 and status=EndOfGroup + let object_writer = self.create_with_status(0, None, ObjectStatus::EndOfGroup)?; + // Object writer with size=0 will complete immediately when dropped + drop(object_writer); + Ok(()) + } + /// Close the stream with an error. pub fn close(self, err: ServeError) -> Result<(), ServeError> { let state = self.state.lock(); diff --git a/moq-transport/src/serve/track.rs b/moq-transport/src/serve/track.rs index 9dd9e101..01591ca8 100644 --- a/moq-transport/src/serve/track.rs +++ b/moq-transport/src/serve/track.rs @@ -199,10 +199,24 @@ impl TrackReader { /// This is used to detect stale cached TrackReaders that should not be reused. pub fn is_closed(&self) -> bool { let state = self.state.lock(); - // Track is closed if: - // 1. It was explicitly closed with an error, OR - // 2. The writer side has been dropped (modified() returns None) - state.closed.is_err() || state.modified().is_none() + + if state.closed.is_err() { + return true; + } + + // Clone the mode out before dropping the TrackState lock to avoid + // nested lock deadlocks (mode readers hold their own State locks). + if let Some(mode) = state.reader_mode.clone() { + // Mode has been set — the TrackWriter was consumed during the + // Track→Subgroups/Stream/Datagrams transition. Liveness is now + // determined by whether the mode-level writer is still alive. + drop(state); + return mode.is_closed(); + } + + // No mode set yet — check if the writer was abandoned before + // transitioning to a specific mode. + state.modified().is_none() } } @@ -234,6 +248,12 @@ macro_rules! track_readers { $(Self::$name(reader) => reader.latest(),)* } } + + pub fn is_closed(&self) -> bool { + match self { + $(Self::$name(reader) => reader.is_closed(),)* + } + } } } } @@ -266,3 +286,151 @@ macro_rules! track_writers { } track_writers!(Track, Stream, Subgroups, Objects, Datagrams,); + +#[cfg(test)] +mod tests { + use super::*; + use crate::coding::TrackNamespace; + use crate::serve::Subgroup; + + #[test] + fn test_is_closed_false_before_mode_set() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (_writer, reader) = track.produce(); + assert!(!reader.is_closed()); + } + + #[test] + fn test_is_closed_true_when_writer_dropped_without_mode() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + drop(writer); + assert!(reader.is_closed()); + } + + #[test] + fn test_is_closed_true_when_explicitly_closed() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + writer.close(ServeError::Cancel).unwrap(); + assert!(reader.is_closed()); + } + + #[test] + fn test_is_closed_false_after_subgroups_transition_while_writer_alive() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let _subgroups_writer = writer.subgroups().expect("subgroups transition should succeed"); + + assert!( + !reader.is_closed(), + "track should NOT be closed while SubgroupsWriter is alive" + ); + } + + #[test] + fn test_is_closed_true_after_subgroups_writer_dropped() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let subgroups_writer = writer.subgroups().expect("subgroups transition should succeed"); + drop(subgroups_writer); + + assert!( + reader.is_closed(), + "track should be closed after SubgroupsWriter is dropped" + ); + } + + #[test] + fn test_is_closed_true_after_subgroups_writer_explicitly_closed() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let subgroups_writer = writer.subgroups().expect("subgroups transition should succeed"); + subgroups_writer.close(ServeError::Cancel).unwrap(); + + assert!( + reader.is_closed(), + "track should be closed after SubgroupsWriter is explicitly closed" + ); + } + + #[test] + fn test_is_closed_false_after_stream_transition_while_writer_alive() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let _stream_writer = writer.stream(0).expect("stream transition should succeed"); + + assert!( + !reader.is_closed(), + "track should NOT be closed while StreamWriter is alive" + ); + } + + #[test] + fn test_is_closed_true_after_stream_writer_dropped() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let stream_writer = writer.stream(0).expect("stream transition should succeed"); + drop(stream_writer); + + assert!( + reader.is_closed(), + "track should be closed after StreamWriter is dropped" + ); + } + + #[test] + fn test_is_closed_false_after_datagrams_transition_while_writer_alive() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let _datagrams_writer = writer.datagrams().expect("datagrams transition should succeed"); + + assert!( + !reader.is_closed(), + "track should NOT be closed while DatagramsWriter is alive" + ); + } + + #[test] + fn test_is_closed_true_after_datagrams_writer_dropped() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let datagrams_writer = writer.datagrams().expect("datagrams transition should succeed"); + drop(datagrams_writer); + + assert!( + reader.is_closed(), + "track should be closed after DatagramsWriter is dropped" + ); + } + + #[test] + fn test_is_closed_false_while_subgroups_actively_writing() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let mut subgroups_writer = + writer.subgroups().expect("subgroups transition should succeed"); + + let _subgroup_writer = subgroups_writer + .create(Subgroup { + group_id: 0, + subgroup_id: 0, + priority: 0, + header_type: None, + }) + .expect("create subgroup should succeed"); + + assert!( + !reader.is_closed(), + "track should NOT be closed while actively writing subgroups" + ); + } +} diff --git a/moq-transport/src/serve/tracks.rs b/moq-transport/src/serve/tracks.rs index 7ce3aef0..cea6e954 100644 --- a/moq-transport/src/serve/tracks.rs +++ b/moq-transport/src/serve/tracks.rs @@ -91,6 +91,21 @@ impl TracksWriter { }; self.state.lock_mut()?.tracks.remove(&full_name) } + + /// Insert an existing track reader into the broadcast. + /// Returns None if all readers have been dropped or if a track with this name already exists. + pub fn insert(&mut self, reader: TrackReader) -> Option<()> { + let full_name = FullTrackName { + namespace: reader.namespace.clone(), + name: reader.name.clone(), + }; + let mut state = self.state.lock_mut()?; + if state.tracks.contains_key(&full_name) { + return None; + } + state.tracks.insert(full_name, reader); + Some(()) + } } impl Deref for TracksWriter { @@ -201,7 +216,6 @@ impl TracksReader { return Some(track_reader.clone()); } // Track is closed/stale, fall through to create a new one - // We'll remove the stale entry and request a fresh track from the publisher } let mut state = state.into_mut()?; @@ -226,6 +240,13 @@ impl TracksReader { Some(track_writer_reader.1) } + + /// Forward an existing track writer to the upstream subscription queue. + /// The writer will be received by [TracksRequest::next()]. + /// Returns None if the queue is closed. + pub fn forward_upstream(&mut self, writer: TrackWriter) -> Option<()> { + self.queue.push(writer).ok() + } } impl Deref for TracksReader { @@ -324,6 +345,73 @@ mod tests { ); } + #[tokio::test] + async fn test_track_not_stale_after_subgroups_transition() { + let namespace = TrackNamespace::from_utf8_path("test/namespace"); + let track_name = "test-track"; + + let (_writer, mut request, mut reader) = Tracks::new(namespace.clone()).produce(); + + let _track_reader_1 = reader + .subscribe(namespace.clone(), track_name) + .expect("first subscribe should succeed"); + + let track_writer = request + .next() + .await + .expect("publisher should receive track request"); + + let _subgroups_writer = track_writer + .subgroups() + .expect("subgroups transition should succeed"); + + let _track_reader_2 = reader + .subscribe(namespace.clone(), track_name) + .expect("second subscribe should succeed"); + + let maybe_second_request = + tokio::time::timeout(std::time::Duration::from_millis(100), request.next()).await; + + assert!( + maybe_second_request.is_err(), + "publisher should NOT get a second request while SubgroupsWriter is alive" + ); + } + + #[tokio::test] + async fn test_track_stale_after_subgroups_writer_dropped() { + let namespace = TrackNamespace::from_utf8_path("test/namespace"); + let track_name = "test-track"; + + let (_writer, mut request, mut reader) = Tracks::new(namespace.clone()).produce(); + + let _track_reader_1 = reader + .subscribe(namespace.clone(), track_name) + .expect("first subscribe should succeed"); + + let track_writer = request + .next() + .await + .expect("publisher should receive track request"); + + let subgroups_writer = track_writer + .subgroups() + .expect("subgroups transition should succeed"); + drop(subgroups_writer); + + let _track_reader_2 = reader + .subscribe(namespace.clone(), track_name) + .expect("second subscribe should succeed"); + + let maybe_second_request = + tokio::time::timeout(std::time::Duration::from_millis(100), request.next()).await; + + assert!( + maybe_second_request.is_ok(), + "publisher should get a new request after SubgroupsWriter is dropped" + ); + } + /// Test that normal track caching works correctly when tracks are still alive. /// /// Multiple subscribers to the same track should share the same TrackReader diff --git a/moq-transport/src/session/announce.rs b/moq-transport/src/session/announce.rs deleted file mode 100644 index 278c614a..00000000 --- a/moq-transport/src/session/announce.rs +++ /dev/null @@ -1,227 +0,0 @@ -use std::{collections::VecDeque, ops}; - -use crate::coding::TrackNamespace; -use crate::watch::State; -use crate::{message, serve::ServeError}; - -use super::{Publisher, Subscribed, TrackStatusRequested}; - -#[derive(Debug, Clone)] -pub struct AnnounceInfo { - pub request_id: u64, - pub namespace: TrackNamespace, -} - -struct AnnounceState { - subscribers: VecDeque, - track_statuses_requested: VecDeque, - ok: bool, - closed: Result<(), ServeError>, -} - -impl Default for AnnounceState { - fn default() -> Self { - Self { - subscribers: Default::default(), - track_statuses_requested: Default::default(), - ok: false, - closed: Ok(()), - } - } -} - -impl Drop for AnnounceState { - fn drop(&mut self) { - for subscriber in self.subscribers.drain(..) { - subscriber - .close(ServeError::not_found_ctx( - "announce dropped before subscription handled", - )) - .ok(); - } - } -} - -#[must_use = "unannounce on drop"] -pub struct Announce { - publisher: Publisher, - state: State, - - pub info: AnnounceInfo, -} - -impl Announce { - pub(super) fn new( - mut publisher: Publisher, - request_id: u64, - namespace: TrackNamespace, - ) -> (Announce, AnnounceRecv) { - let info = AnnounceInfo { - request_id, - namespace: namespace.clone(), - }; - - publisher.send_message(message::PublishNamespace { - id: request_id, - track_namespace: namespace.clone(), - params: Default::default(), - }); - - let (send, recv) = State::default().split(); - - let send = Self { - publisher, - info, - state: send, - }; - let recv = AnnounceRecv { - state: recv, - request_id, - }; - - (send, recv) - } - - // Run until we get an error - pub async fn closed(&self) -> Result<(), ServeError> { - loop { - { - let state = self.state.lock(); - state.closed.clone()?; - - match state.modified() { - Some(notified) => notified, - None => return Ok(()), - } - } - .await; - } - } - - /// Wait until a subscriber is received - pub async fn subscribed(&self) -> Result, ServeError> { - loop { - { - let state = self.state.lock(); - if !state.subscribers.is_empty() { - return Ok(state - .into_mut() - .and_then(|mut state| state.subscribers.pop_front())); - } - - state.closed.clone()?; - match state.modified() { - Some(notified) => notified, - None => return Ok(None), - } - } - .await; - } - } - - pub async fn track_status_requested(&self) -> Result, ServeError> { - loop { - { - let state = self.state.lock(); - if !state.track_statuses_requested.is_empty() { - return Ok(state - .into_mut() - .and_then(|mut state| state.track_statuses_requested.pop_front())); - } - - state.closed.clone()?; - match state.modified() { - Some(notified) => notified, - None => return Ok(None), - } - } - .await; - } - } - - // Wait until an OK is received - pub async fn ok(&self) -> Result<(), ServeError> { - loop { - { - let state = self.state.lock(); - if state.ok { - return Ok(()); - } - state.closed.clone()?; - - match state.modified() { - Some(notified) => notified, - None => return Ok(()), - } - } - .await; - } - } -} - -impl Drop for Announce { - fn drop(&mut self) { - if self.state.lock().closed.is_err() { - return; - } - - self.publisher.send_message(message::PublishNamespaceDone { - track_namespace: self.namespace.clone(), - }); - } -} - -impl ops::Deref for Announce { - type Target = AnnounceInfo; - - fn deref(&self) -> &Self::Target { - &self.info - } -} - -pub(super) struct AnnounceRecv { - state: State, - pub request_id: u64, // TODO SLG - Announcements need to be looked up by both request_id and namespace, consider 2 hashmaps in publisher instead of this -} - -impl AnnounceRecv { - pub fn recv_ok(&mut self) -> Result<(), ServeError> { - if let Some(mut state) = self.state.lock_mut() { - if state.ok { - return Err(ServeError::Duplicate); - } - - state.ok = true; - } - - Ok(()) - } - - pub fn recv_error(self, err: ServeError) -> Result<(), ServeError> { - let state = self.state.lock(); - state.closed.clone()?; - - let mut state = state.into_mut().ok_or(ServeError::Done)?; - state.closed = Err(err); - - Ok(()) - } - - pub fn recv_subscribe(&mut self, subscriber: Subscribed) -> Result<(), ServeError> { - let mut state = self.state.lock_mut().ok_or(ServeError::Done)?; - state.subscribers.push_back(subscriber); - - Ok(()) - } - - pub fn recv_track_status_requested( - &mut self, - track_status_requested: TrackStatusRequested, - ) -> Result<(), ServeError> { - let mut state = self.state.lock_mut().ok_or(ServeError::Done)?; - state - .track_statuses_requested - .push_back(track_status_requested); - Ok(()) - } -} diff --git a/moq-transport/src/session/error.rs b/moq-transport/src/session/error.rs index 25006b9f..2e5ceacb 100644 --- a/moq-transport/src/session/error.rs +++ b/moq-transport/src/session/error.rs @@ -2,14 +2,8 @@ use crate::{coding, serve, setup}; #[derive(thiserror::Error, Debug, Clone)] pub enum SessionError { - #[error("webtransport session: {0}")] - Session(#[from] web_transport::SessionError), - - #[error("webtransport write: {0}")] - Write(#[from] web_transport::WriteError), - - #[error("webtransport read: {0}")] - Read(#[from] web_transport::ReadError), + #[error("webtransport error: {0}")] + WebTransport(#[from] web_transport::Error), #[error("encode error: {0}")] Encode(#[from] coding::EncodeError), @@ -53,9 +47,7 @@ impl SessionError { // PROTOCOL_VIOLATION (0x3) - The role negotiated in the handshake was violated Self::RoleViolation => 0x3, // INTERNAL_ERROR (0x1) - Generic internal errors - Self::Session(_) => 0x1, - Self::Read(_) => 0x1, - Self::Write(_) => 0x1, + Self::WebTransport(_) => 0x1, Self::Encode(_) => 0x1, Self::BoundsExceeded(_) => 0x1, Self::Internal => 0x1, diff --git a/moq-transport/src/session/mod.rs b/moq-transport/src/session/mod.rs index 3b01dbd1..2b6dea17 100644 --- a/moq-transport/src/session/mod.rs +++ b/moq-transport/src/session/mod.rs @@ -1,19 +1,27 @@ -mod announce; -mod announced; mod error; +mod publish_namespace; +mod publish_namespace_received; +mod publish_received; +mod published; mod publisher; mod reader; mod subscribe; +mod subscribe_namespace; +mod subscribe_namespace_received; mod subscribed; mod subscriber; mod track_status_requested; mod writer; -pub use announce::*; -pub use announced::*; pub use error::*; +pub use publish_namespace::*; +pub use publish_namespace_received::*; +pub use publish_received::*; +pub use published::*; pub use publisher::*; pub use subscribe::*; +pub use subscribe_namespace::*; +pub use subscribe_namespace_received::*; pub use subscribed::*; pub use subscriber::*; pub use track_status_requested::*; @@ -52,14 +60,6 @@ pub struct Session { } impl Session { - // Helper for determining the largest supported version - fn largest_common(a: &[T], b: &[T]) -> Option { - a.iter() - .filter(|x| b.contains(x)) // keep only items also in b - .cloned() // clone because we return T, not &T - .max() // take the largest - } - fn new( webtransport: web_transport::Session, sender: Writer, @@ -101,7 +101,7 @@ impl Session { /// Create an outbound/client QUIC connection, by opening a bi-directional QUIC stream for /// MOQT control messaging. Performs SETUP messaging and version negotiation. pub async fn connect( - mut session: web_transport::Session, + session: web_transport::Session, mlog_path: Option, ) -> Result<(Session, Publisher, Subscriber), SessionError> { let mlog = mlog_path.and_then(|path| { @@ -113,16 +113,10 @@ impl Session { let mut sender = Writer::new(control.0); let mut recver = Reader::new(control.1); - let versions: setup::Versions = [setup::Version::DRAFT_14].into(); - - // TODO SLG - make configurable? let mut params = KeyValuePairs::default(); params.set_intvalue(setup::ParameterType::MaxRequestId.into(), 100); - let client = setup::Client { - versions: versions.clone(), - params, - }; + let client = setup::Client { params }; log::debug!("sending CLIENT_SETUP: {:?}", client); sender.encode(&client).await?; @@ -142,7 +136,7 @@ impl Session { /// Accepts an inbound/server QUIC connection, by accepting a bi-directional QUIC stream for /// MOQT control messaging. Performs SETUP messaging and version negotiation. pub async fn accept( - mut session: web_transport::Session, + session: web_transport::Session, mlog_path: Option, ) -> Result<(Session, Option, Option), SessionError> { let mut mlog = mlog_path.and_then(|path| { @@ -163,35 +157,24 @@ impl Session { let _ = mlog.add_event(event); } - let server_versions = setup::Versions(vec![setup::Version::DRAFT_14]); - - if let Some(largest_common_version) = - Self::largest_common(&server_versions, &client.versions) - { - // TODO SLG - make configurable? - let mut params = KeyValuePairs::default(); - params.set_intvalue(setup::ParameterType::MaxRequestId.into(), 100); + // TODO SLG - make configurable? + let mut params = KeyValuePairs::default(); + params.set_intvalue(setup::ParameterType::MaxRequestId.into(), 100); - let server = setup::Server { - version: largest_common_version, - params, - }; + let server = setup::Server { params }; - log::debug!("sending SERVER_SETUP: {:?}", server); + log::debug!("sending SERVER_SETUP: {:?}", server); - // Emit mlog event for SERVER_SETUP created - if let Some(ref mut mlog) = mlog { - let event = mlog::events::server_setup_created(mlog.elapsed_ms(), 0, &server); - let _ = mlog.add_event(event); - } + // Emit mlog event for SERVER_SETUP created + if let Some(ref mut mlog) = mlog { + let event = mlog::events::server_setup_created(mlog.elapsed_ms(), 0, &server); + let _ = mlog.add_event(event); + } - sender.encode(&server).await?; + sender.encode(&server).await?; - // We are the server, so the first request id is 1 - Ok(Session::new(session, sender, recver, 1, mlog)) - } else { - Err(SessionError::Version(client.versions, server_versions)) - } + // We are the server, so the first request id is 1 + Ok(Session::new(session, sender, recver, 1, mlog)) } /// Run Tasks for the session, including sending of control messages, receiving and processing @@ -199,9 +182,10 @@ impl Session { /// and receiving and processing QUIC datagrams received pub async fn run(self) -> Result<(), SessionError> { tokio::select! { - res = Self::run_recv(self.recver, self.publisher, self.subscriber.clone(), self.mlog.clone()) => res, + res = Self::run_recv(self.recver, self.publisher.clone(), self.subscriber.clone(), self.mlog.clone()) => res, res = Self::run_send(self.sender, self.outgoing, self.mlog.clone()) => res, res = Self::run_streams(self.webtransport.clone(), self.subscriber.clone()) => res, + res = Self::run_bidi_streams(self.webtransport.clone(), self.publisher) => res, res = Self::run_datagrams(self.webtransport, self.subscriber) => res, } } @@ -229,8 +213,8 @@ impl Session { Message::SubscribeOk(m) => { Some(mlog::events::subscribe_ok_created(time, stream_id, m)) } - Message::SubscribeError(m) => { - Some(mlog::events::subscribe_error_created(time, stream_id, m)) + Message::RequestError(m) => { + Some(mlog::events::reqeust_error_created(time, stream_id, m)) } Message::Unsubscribe(m) => { Some(mlog::events::unsubscribe_created(time, stream_id, m)) @@ -238,16 +222,22 @@ impl Session { Message::PublishNamespace(m) => { Some(mlog::events::publish_namespace_created(time, stream_id, m)) } - Message::PublishNamespaceOk(m) => Some( - mlog::events::publish_namespace_ok_created(time, stream_id, m), - ), - Message::PublishNamespaceError(m) => Some( - mlog::events::publish_namespace_error_created(time, stream_id, m), - ), + Message::RequestOk(m) => { + Some(mlog::events::reqeust_ok_created(time, stream_id, m)) + } Message::GoAway(m) => { Some(mlog::events::go_away_created(time, stream_id, m)) } - _ => None, // TODO: Add other message types + Message::Publish(m) => { + Some(mlog::events::publish_created(time, stream_id, m)) + } + Message::PublishOk(m) => { + Some(mlog::events::publish_ok_created(time, stream_id, m)) + } + Message::PublishDone(m) => { + Some(mlog::events::publish_done_created(time, stream_id, m)) + } + _ => None, }; if let Some(event) = event { @@ -291,8 +281,8 @@ impl Session { Message::SubscribeOk(m) => { Some(mlog::events::subscribe_ok_parsed(time, stream_id, m)) } - Message::SubscribeError(m) => { - Some(mlog::events::subscribe_error_parsed(time, stream_id, m)) + Message::RequestError(m) => { + Some(mlog::events::request_error_parsed(time, stream_id, m)) } Message::Unsubscribe(m) => { Some(mlog::events::unsubscribe_parsed(time, stream_id, m)) @@ -300,16 +290,22 @@ impl Session { Message::PublishNamespace(m) => { Some(mlog::events::publish_namespace_parsed(time, stream_id, m)) } - Message::PublishNamespaceOk(m) => Some( - mlog::events::publish_namespace_ok_parsed(time, stream_id, m), - ), - Message::PublishNamespaceError(m) => Some( - mlog::events::publish_namespace_error_parsed(time, stream_id, m), - ), + Message::RequestOk(m) => { + Some(mlog::events::request_ok_parsed(time, stream_id, m)) + } Message::GoAway(m) => { Some(mlog::events::go_away_parsed(time, stream_id, m)) } - _ => None, // TODO: Add other message types + Message::Publish(m) => { + Some(mlog::events::publish_parsed(time, stream_id, m)) + } + Message::PublishOk(m) => { + Some(mlog::events::publish_ok_parsed(time, stream_id, m)) + } + Message::PublishDone(m) => { + Some(mlog::events::publish_done_parsed(time, stream_id, m)) + } + _ => None, }; if let Some(event) = event { @@ -318,6 +314,29 @@ impl Session { } } + // RequestOk and RequestError are bidirectional — they can be responses + // to requests originated by either side (e.g., PUBLISH_NAMESPACE from the + // publisher or SUBSCRIBE_NAMESPACE from the subscriber). We must try both + // handlers so the response reaches whichever side owns that request ID. + match &msg { + Message::RequestOk(_) | Message::RequestError(_) => { + // Try subscriber handler first (for SUBSCRIBE_NAMESPACE responses) + if let Ok(pub_msg) = TryInto::::try_into(msg.clone()) { + if let Some(sub) = subscriber.as_mut() { + let _ = sub.recv_message(pub_msg); + } + } + // Also try publisher handler (for PUBLISH_NAMESPACE responses) + if let Ok(sub_msg) = TryInto::::try_into(msg) { + if let Some(pub_) = publisher.as_mut() { + let _ = pub_.recv_message(sub_msg); + } + } + continue; + } + _ => {} + } + let msg = match TryInto::::try_into(msg) { Ok(msg) => { subscriber @@ -353,7 +372,7 @@ impl Session { /// Will read stream header to know what type of stream it is and create /// the appropriate stream handlers. async fn run_streams( - mut webtransport: web_transport::Session, + webtransport: web_transport::Session, subscriber: Option, ) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); @@ -375,9 +394,55 @@ impl Session { } } + /// Accepts bidirectional QUIC streams for messages like SUBSCRIBE_NAMESPACE. + /// In draft-16, SUBSCRIBE_NAMESPACE uses its own bidirectional stream. + async fn run_bidi_streams( + webtransport: web_transport::Session, + publisher: Option, + ) -> Result<(), SessionError> { + let mut tasks = FuturesUnordered::new(); + + loop { + tokio::select! { + res = webtransport.accept_bi() => { + let (_send, recv) = res?; + let mut publisher = publisher.clone().ok_or(SessionError::RoleViolation)?; + + tasks.push(async move { + let mut reader = Reader::new(recv); + + // Read the message from the bidi stream + let msg: message::Message = match reader.decode().await { + Ok(msg) => msg, + Err(e) => { + log::warn!("failed to decode message on bidi stream: {}", e); + return; + } + }; + + log::debug!("received message on bidi stream: {:?}", msg); + + // Handle SUBSCRIBE_NAMESPACE on its dedicated bidi stream + match msg { + Message::SubscribeNamespace(subscribe_ns) => { + if let Err(e) = publisher.recv_message(message::Subscriber::SubscribeNamespace(subscribe_ns)) { + log::warn!("failed to handle SUBSCRIBE_NAMESPACE: {}", e); + } + } + other => { + log::warn!("unexpected message type on bidi stream: {:?}", other); + } + } + }); + }, + _ = tasks.next(), if !tasks.is_empty() => {}, + }; + } + } + /// Receives QUIC datagrams and processes them using the Subscriber logic async fn run_datagrams( - mut webtransport: web_transport::Session, + webtransport: web_transport::Session, mut subscriber: Option, ) -> Result<(), SessionError> { loop { diff --git a/moq-transport/src/session/publish_namespace.rs b/moq-transport/src/session/publish_namespace.rs new file mode 100644 index 00000000..14f7fcee --- /dev/null +++ b/moq-transport/src/session/publish_namespace.rs @@ -0,0 +1,157 @@ +use std::ops; + +use crate::coding::TrackNamespace; +use crate::watch::State; +use crate::{message, serve::ServeError}; + +use super::Publisher; + +#[derive(Debug, Clone)] +pub struct PublishNamespaceInfo { + pub request_id: u64, + pub namespace: TrackNamespace, +} + +/// Internal state for PublishNamespace. +/// +/// PublishNamespace is a namespace registry that advertises to subscribers +/// that a publisher has tracks available in a namespace. It does NOT route +/// subscriptions - that happens via PUBLISH/SUBSCRIBE messages directly. +struct PublishNamespaceState { + ok: bool, + closed: Result<(), ServeError>, +} + +impl Default for PublishNamespaceState { + fn default() -> Self { + Self { + ok: false, + closed: Ok(()), + } + } +} + +/// Represents an outbound PUBLISH_NAMESPACE request (publisher side). +/// When dropped, sends PUBLISH_NAMESPACE_DONE to the peer. +#[must_use = "sends PUBLISH_NAMESPACE_DONE on drop"] +pub struct PublishNamespace { + publisher: Publisher, + state: State, + + pub info: PublishNamespaceInfo, +} + +impl PublishNamespace { + pub(super) fn new( + mut publisher: Publisher, + request_id: u64, + namespace: TrackNamespace, + ) -> (PublishNamespace, PublishNamespaceRecv) { + let info = PublishNamespaceInfo { + request_id, + namespace: namespace.clone(), + }; + + publisher.send_message(message::PublishNamespace { + id: request_id, + track_namespace: namespace.clone(), + params: Default::default(), + }); + + let (send, recv) = State::default().split(); + + let send = Self { + publisher, + info, + state: send, + }; + let recv = PublishNamespaceRecv { + state: recv, + request_id, + }; + + (send, recv) + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } + + pub async fn ok(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + if state.ok { + return Ok(()); + } + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } +} + +impl Drop for PublishNamespace { + fn drop(&mut self) { + if self.state.lock().closed.is_err() { + return; + } + + self.publisher.send_message(message::PublishNamespaceDone { + track_namespace: self.namespace.clone(), + }); + } +} + +impl ops::Deref for PublishNamespace { + type Target = PublishNamespaceInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +pub(super) struct PublishNamespaceRecv { + state: State, + pub request_id: u64, +} + +impl PublishNamespaceRecv { + pub fn recv_ok(&mut self) -> Result<(), ServeError> { + if let Some(mut state) = self.state.lock_mut() { + if state.ok { + return Err(ServeError::Duplicate); + } + + state.ok = true; + } + + Ok(()) + } + + pub fn recv_error(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + + Ok(()) + } +} diff --git a/moq-transport/src/session/announced.rs b/moq-transport/src/session/publish_namespace_received.rs similarity index 61% rename from moq-transport/src/session/announced.rs rename to moq-transport/src/session/publish_namespace_received.rs index 5b2e466a..a25e6b76 100644 --- a/moq-transport/src/session/announced.rs +++ b/moq-transport/src/session/publish_namespace_received.rs @@ -4,30 +4,30 @@ use crate::coding::{ReasonPhrase, TrackNamespace}; use crate::watch::State; use crate::{message, serve::ServeError}; -use super::{AnnounceInfo, Subscriber}; +use super::{PublishNamespaceInfo, Subscriber}; -// There's currently no feedback from the peer, so the shared state is empty. -// If Unannounce contained an error code then we'd be talking. #[derive(Default)] -struct AnnouncedState {} +struct PublishNamespaceReceivedState {} -pub struct Announced { +/// Represents an inbound PUBLISH_NAMESPACE that was received (subscriber side). +/// When dropped, sends PUBLISH_NAMESPACE_CANCEL (if ok'd) or PUBLISH_NAMESPACE_ERROR. +pub struct PublishNamespaceReceived { session: Subscriber, - state: State, + state: State, - pub info: AnnounceInfo, + pub info: PublishNamespaceInfo, ok: bool, error: Option, } -impl Announced { +impl PublishNamespaceReceived { pub(super) fn new( session: Subscriber, request_id: u64, namespace: TrackNamespace, - ) -> (Announced, AnnouncedRecv) { - let info = AnnounceInfo { + ) -> (PublishNamespaceReceived, PublishNamespaceReceivedRecv) { + let info = PublishNamespaceInfo { request_id, namespace, }; @@ -40,19 +40,19 @@ impl Announced { error: None, state: send, }; - let recv = AnnouncedRecv { _state: recv }; + let recv = PublishNamespaceReceivedRecv { _state: recv }; (send, recv) } - // Send an ANNOUNCE_OK pub fn ok(&mut self) -> Result<(), ServeError> { if self.ok { return Err(ServeError::Duplicate); } - self.session.send_message(message::PublishNamespaceOk { + self.session.send_message(message::RequestOk { id: self.info.request_id, + params: Default::default(), }); self.ok = true; @@ -62,8 +62,6 @@ impl Announced { pub async fn closed(&self) -> Result<(), ServeError> { loop { - // Wow this is dumb and yet pretty cool. - // Basically loop until the state changes and exit when Recv is dropped. self.state .lock() .modified() @@ -78,19 +76,18 @@ impl Announced { } } -impl ops::Deref for Announced { - type Target = AnnounceInfo; +impl ops::Deref for PublishNamespaceReceived { + type Target = PublishNamespaceInfo; - fn deref(&self) -> &AnnounceInfo { + fn deref(&self) -> &PublishNamespaceInfo { &self.info } } -impl Drop for Announced { +impl Drop for PublishNamespaceReceived { fn drop(&mut self) { let err = self.error.clone().unwrap_or(ServeError::Done); - // TODO SLG - ServeError's do not align with draft-13 Announce error codes (section 8.25) if self.ok { self.session.send_message(message::PublishNamespaceCancel { track_namespace: self.namespace.clone(), @@ -98,22 +95,22 @@ impl Drop for Announced { reason_phrase: ReasonPhrase(err.to_string()), }); } else { - self.session.send_message(message::PublishNamespaceError { + self.session.send_message(message::RequestError { id: self.info.request_id, error_code: err.code(), + retry_interval: 0, reason_phrase: ReasonPhrase(err.to_string()), }); } } } -pub(super) struct AnnouncedRecv { - _state: State, +pub(super) struct PublishNamespaceReceivedRecv { + _state: State, } -impl AnnouncedRecv { - pub fn recv_unannounce(self) -> Result<(), ServeError> { - // Will cause the state to be dropped +impl PublishNamespaceReceivedRecv { + pub fn recv_done(self) -> Result<(), ServeError> { Ok(()) } } diff --git a/moq-transport/src/session/publish_received.rs b/moq-transport/src/session/publish_received.rs new file mode 100644 index 00000000..e0703aba --- /dev/null +++ b/moq-transport/src/session/publish_received.rs @@ -0,0 +1,297 @@ +use std::ops; + +use crate::coding::{ReasonPhrase, TrackNamespace}; +use crate::data::ExtensionHeaders; +use crate::serve::ServeError; +use crate::watch::State; +use crate::{data, message, serve}; + +use super::Subscriber; + +#[derive(Debug, Clone)] +pub struct PublishReceivedInfo { + pub id: u64, + pub track_namespace: TrackNamespace, + pub track_name: String, + pub track_alias: u64, + /// Forward parameter from PUBLISH (0x10): true = forward immediately, false = paused + pub forward: bool, + /// Track extensions from the original PUBLISH message + pub track_extensions: ExtensionHeaders, +} + +impl PublishReceivedInfo { + pub fn new_from_publish(msg: &message::Publish) -> Self { + // Forward parameter (0x10): default to true if not present + // Value of 0 means paused, 1 (or non-zero) means forward + let forward = msg + .params + .get_intvalue(0x10) // ParameterType::Forward + .map(|v| v != 0) + .unwrap_or(true); + + Self { + id: msg.id, + track_namespace: msg.track_namespace.clone(), + track_name: msg.track_name.clone(), + track_alias: msg.track_alias, + forward, + track_extensions: msg.track_extensions.clone(), + } + } +} + +struct PublishReceivedState { + ok: bool, + closed: Result<(), ServeError>, + writer: Option, +} + +impl Default for PublishReceivedState { + fn default() -> Self { + Self { + ok: false, + closed: Ok(()), + writer: None, + } + } +} + +#[must_use = "sends PUBLISH_ERROR on drop if not accepted"] +pub struct PublishReceived { + subscriber: Subscriber, + pub info: PublishReceivedInfo, + state: State, + ok: bool, +} + +impl PublishReceived { + pub(super) fn new( + subscriber: Subscriber, + msg: &message::Publish, + ) -> (Self, PublishReceivedRecv) { + let info = PublishReceivedInfo::new_from_publish(msg); + + let (send, recv) = State::default().split(); + + let send = Self { + subscriber, + info, + state: send, + ok: false, + }; + + let recv = PublishReceivedRecv { + state: recv, + writer_mode: None, + }; + + (send, recv) + } + + pub fn accept( + mut self, + track: serve::TrackWriter, + publish_msg: message::PublishOk, + ) -> Result<(), ServeError> { + let state = self.state.lock(); + if state.ok { + return Err(ServeError::Duplicate); + } + state.closed.clone()?; + + self.subscriber.send_message(publish_msg); + + if let Some(mut state) = state.into_mut() { + state.ok = true; + state.writer = Some(track); + } + + self.ok = true; + + std::mem::forget(self); + + Ok(()) + } + + pub fn reject(mut self, error_code: u64, reason: &str) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + self.subscriber.send_message(message::RequestError { + id: self.info.id, + error_code, + retry_interval: 0, + reason_phrase: ReasonPhrase(reason.to_string()), + }); + + if let Some(mut state) = state.into_mut() { + state.closed = Err(ServeError::Closed(error_code)); + } + + std::mem::forget(self); + + Ok(()) + } + + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + + Ok(()) + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notify) => notify, + None => return Ok(()), + } + } + .await; + } + } +} + +impl ops::Deref for PublishReceived { + type Target = PublishReceivedInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +impl Drop for PublishReceived { + fn drop(&mut self) { + if self.ok { + return; + } + + let state = self.state.lock(); + let err = state + .closed + .as_ref() + .err() + .cloned() + .unwrap_or(ServeError::NotFound); + drop(state); + + self.subscriber.send_message(message::RequestError { + id: self.info.id, + error_code: err.code(), + retry_interval: 0, + reason_phrase: ReasonPhrase(err.to_string()), + }); + } +} + +pub(super) struct PublishReceivedRecv { + state: State, + writer_mode: Option, +} + +impl PublishReceivedRecv { + pub fn track_alias(&self) -> Option { + None + } + + pub fn recv_done(&mut self) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + if let Some(mut state) = state.into_mut() { + state.closed = Err(ServeError::Done); + } + + Ok(()) + } + + fn take_writer(&mut self) -> Result { + if let Some(writer) = self.writer_mode.take() { + return Ok(writer); + } + + let mut state = self.state.lock_mut().ok_or(ServeError::Done)?; + let writer = state.writer.take().ok_or(ServeError::Done)?; + Ok(writer.into()) + } + + fn put_writer(&mut self, writer: serve::TrackWriterMode) { + self.writer_mode = Some(writer); + } + + pub fn subgroup( + &mut self, + header: data::SubgroupHeader, + ) -> Result { + let writer = self.take_writer()?; + + let mut subgroups = match writer { + serve::TrackWriterMode::Track(track) => track.subgroups()?, + serve::TrackWriterMode::Subgroups(subgroups) => subgroups, + _ => return Err(ServeError::Mode), + }; + + let result = subgroups.create(serve::Subgroup { + group_id: header.group_id, + subgroup_id: header.subgroup_id.unwrap_or(0), + priority: header.publisher_priority.unwrap_or(127), + header_type: Some(header.header_type), + }); + + // Always put writer back, even on error, to avoid losing it + self.put_writer(subgroups.into()); + + result + } + + pub fn datagram(&mut self, datagram: data::Datagram) -> Result<(), ServeError> { + let writer = self.take_writer()?; + + // Determine status from datagram type or explicit status field + let status = if datagram.datagram_type.is_end_of_group() { + Some(crate::data::ObjectStatus::EndOfGroup) + } else { + datagram.status + }; + + match writer { + serve::TrackWriterMode::Track(track) => { + let mut datagrams = track.datagrams()?; + datagrams.write(serve::Datagram { + group_id: datagram.group_id, + object_id: datagram.object_id.unwrap_or(0), + priority: datagram.publisher_priority.unwrap_or(127), + payload: datagram.payload.unwrap_or_default(), + extension_headers: datagram.extension_headers.unwrap_or_default(), + status, + })?; + self.put_writer(serve::TrackWriterMode::Datagrams(datagrams)); + Ok(()) + } + serve::TrackWriterMode::Datagrams(mut datagrams) => { + datagrams.write(serve::Datagram { + group_id: datagram.group_id, + object_id: datagram.object_id.unwrap_or(0), + priority: datagram.publisher_priority.unwrap_or(127), + payload: datagram.payload.unwrap_or_default(), + extension_headers: datagram.extension_headers.unwrap_or_default(), + status, + })?; + self.put_writer(serve::TrackWriterMode::Datagrams(datagrams)); + Ok(()) + } + other => { + self.put_writer(other); + Err(ServeError::Mode) + } + } + } +} diff --git a/moq-transport/src/session/published.rs b/moq-transport/src/session/published.rs new file mode 100644 index 00000000..bf004def --- /dev/null +++ b/moq-transport/src/session/published.rs @@ -0,0 +1,633 @@ +use std::ops; +use std::sync::{Arc, Mutex}; + +use futures::stream::FuturesUnordered; +use futures::StreamExt; + +use crate::coding::{Encode, Location, ReasonPhrase, TrackNamespace}; +use crate::message::ParameterType; +use crate::mlog; +use crate::serve::{ServeError, TrackReaderMode}; +use crate::watch::State; +use crate::{data, message, serve}; + +use super::{Publisher, SessionError, Writer}; + +#[derive(Debug, Clone)] +pub struct PublishInfo { + pub id: u64, + pub track_namespace: TrackNamespace, + pub track_name: String, + pub track_alias: u64, +} + +impl PublishInfo { + pub fn new_from_publish(msg: &message::Publish) -> Self { + Self { + id: msg.id, + track_namespace: msg.track_namespace.clone(), + track_name: msg.track_name.clone(), + track_alias: msg.track_alias, + } + } +} + +#[derive(Debug)] +struct PublishedState { + ok: bool, + forward: bool, + subscriber_priority: u8, + group_order: message::GroupOrder, + largest_location: Option, + closed: Result<(), ServeError>, +} + +impl PublishedState { + fn update_largest_location(&mut self, group_id: u64, object_id: u64) -> Result<(), ServeError> { + let new_location = Location::new(group_id, object_id); + if let Some(current) = self.largest_location { + if new_location > current { + self.largest_location = Some(new_location); + } + } else { + self.largest_location = Some(new_location); + } + Ok(()) + } +} + +impl Default for PublishedState { + fn default() -> Self { + Self { + ok: false, + forward: true, + subscriber_priority: 128, + group_order: message::GroupOrder::Ascending, + largest_location: None, + closed: Ok(()), + } + } +} + +#[must_use = "sends PUBLISH_DONE on drop"] +pub struct Published { + publisher: Publisher, + pub info: PublishInfo, + state: State, + ok: bool, + mlog: Option>>, +} + +impl Published { + pub(super) fn new( + mut publisher: Publisher, + msg: message::Publish, + mlog: Option>>, + ) -> (Self, PublishedRecv) { + let info = PublishInfo::new_from_publish(&msg); + + publisher.send_message(msg); + + let (send, recv) = State::default().split(); + + let send = Self { + publisher, + info, + state: send, + ok: false, + mlog, + }; + + let recv = PublishedRecv { state: recv }; + + (send, recv) + } + + pub async fn ok(&mut self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + if state.ok { + self.ok = true; + return Ok(()); + } + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notify) => notify, + None => return Ok(()), + } + } + .await; + } + } + + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + + Ok(()) + } + + pub async fn serve(mut self, track: serve::TrackReader) -> Result<(), SessionError> { + let res = self.serve_inner(track).await; + if let Err(err) = &res { + self.close(err.clone().into())?; + } + res + } + + /// Serve using a pre-acquired TrackReaderMode. + /// Use this when you need to acquire the mode early (before network round trips) + /// to avoid missing frames in late-join scenarios. + pub async fn serve_mode(mut self, mode: TrackReaderMode) -> Result<(), SessionError> { + let res = self.serve_mode_inner(mode).await; + if let Err(err) = &res { + self.close(err.clone().into())?; + } + res + } + + /// Serve immediately without waiting for PUBLISH_OK. + /// Use this for relay scenarios where you want to start forwarding data right away. + /// The subscriber will receive data as soon as they're ready. + pub async fn serve_immediately(mut self, track: serve::TrackReader) -> Result<(), SessionError> { + let res = self.serve_immediately_inner(track).await; + if let Err(err) = &res { + self.close(err.clone().into())?; + } + res + } + + async fn serve_inner(&mut self, track: serve::TrackReader) -> Result<(), SessionError> { + self.ok().await?; + + let forward = { + let state = self.state.lock(); + state.forward + }; + + if !forward { + self.closed().await?; + return Ok(()); + } + + match track.mode().await? { + TrackReaderMode::Stream(_stream) => panic!("deprecated"), + TrackReaderMode::Subgroups(subgroups) => self.serve_subgroups(subgroups).await, + TrackReaderMode::Datagrams(datagrams) => self.serve_datagrams(datagrams).await, + } + } + + async fn serve_mode_inner(&mut self, mode: TrackReaderMode) -> Result<(), SessionError> { + self.ok().await?; + + let forward = { + let state = self.state.lock(); + state.forward + }; + + if !forward { + self.closed().await?; + return Ok(()); + } + + match mode { + TrackReaderMode::Stream(_stream) => panic!("deprecated"), + TrackReaderMode::Subgroups(subgroups) => self.serve_subgroups(subgroups).await, + TrackReaderMode::Datagrams(datagrams) => self.serve_datagrams(datagrams).await, + } + } + + async fn serve_immediately_inner(&mut self, track: serve::TrackReader) -> Result<(), SessionError> { + // Don't wait for PUBLISH_OK - start streaming immediately + // This is useful for relay scenarios where we want minimal latency + + match track.mode().await? { + TrackReaderMode::Stream(_stream) => panic!("deprecated"), + TrackReaderMode::Subgroups(subgroups) => self.serve_subgroups(subgroups).await, + TrackReaderMode::Datagrams(datagrams) => self.serve_datagrams(datagrams).await, + } + } + + async fn serve_subgroups( + &mut self, + mut subgroups: serve::SubgroupsReader, + ) -> Result<(), SessionError> { + let mut tasks = FuturesUnordered::new(); + let mut done: Option> = None; + + loop { + tokio::select! { + res = subgroups.next(), if done.is_none() => match res { + Ok(Some(subgroup)) => { + // Header type will be determined in serve_subgroup based on extension headers + let track_alias = self.info.track_alias; + let publisher = self.publisher.clone(); + let state = self.state.clone(); + let info = subgroup.info.clone(); + let mlog = self.mlog.clone(); + + tasks.push(async move { + if let Err(err) = Self::serve_subgroup(track_alias, subgroup, publisher, state, mlog).await { + log::warn!("failed to serve subgroup: {:?}, error: {}", info, err); + } + }); + }, + Ok(None) => done = Some(Ok(())), + Err(err) => done = Some(Err(err)), + }, + res = self.closed(), if done.is_none() => done = Some(res), + _ = tasks.next(), if !tasks.is_empty() => {}, + else => return Ok(done.unwrap()?), + } + } + } + + async fn serve_subgroup( + track_alias: u64, + mut subgroup_reader: serve::SubgroupReader, + mut publisher: Publisher, + state: State, + mlog: Option>>, + ) -> Result<(), SessionError> { + log::debug!( + "[PUBLISHED] serve_subgroup: starting - track_alias={}, group_id={}, subgroup_id={:?}, priority={}", + track_alias, + subgroup_reader.group_id, + subgroup_reader.subgroup_id, + subgroup_reader.priority + ); + + // Read the first object to determine if we have extension headers + let first_object = match subgroup_reader.next().await? { + Some(obj) => obj, + None => { + log::debug!("[PUBLISHED] serve_subgroup: no objects in subgroup, skipping"); + return Ok(()); + } + }; + + // Use preserved header type if available, otherwise determine from extension headers + let has_extension_headers = !first_object.extension_headers.is_empty(); + let header_type = subgroup_reader.info.header_type.unwrap_or_else(|| { + // Fallback: determine header type based on extension headers + if has_extension_headers { + data::StreamHeaderType::SubgroupZeroIdExtEndOfGroup + } else { + data::StreamHeaderType::SubgroupZeroIdEndOfGroup + } + }); + + // If we're not writing extension headers but the preserved header type has extensions, + // convert to the non-Ext variant to avoid mismatch between header and object encoding + let header_type = if !has_extension_headers && header_type.has_extension_headers() { + log::debug!( + "[PUBLISHED] serve_subgroup: converting header_type {:?} to non-Ext variant (objects have no extensions)", + header_type + ); + header_type.without_extensions() + } else { + header_type + }; + + // Set subgroup_id based on header type (ZeroId variants don't include it on wire) + let subgroup_id = if header_type.has_subgroup_id() { + Some(subgroup_reader.subgroup_id) + } else { + None + }; + + let header = data::SubgroupHeader { + header_type, + track_alias, + group_id: subgroup_reader.group_id, + subgroup_id, + publisher_priority: Some(subgroup_reader.priority), + }; + + let mut send_stream = publisher.open_uni().await?; + send_stream.set_priority(subgroup_reader.priority as i32); + + let mut writer = Writer::new(send_stream); + + log::debug!( + "[PUBLISHED] serve_subgroup: sending header - track_alias={}, group_id={}, subgroup_id={:?}, priority={:?}, header_type={:?}, has_ext={}", + header.track_alias, + header.group_id, + header.subgroup_id, + header.publisher_priority, + header.header_type, + has_extension_headers + ); + + writer.encode(&header).await?; + + if let Some(ref mlog) = mlog { + if let Ok(mut mlog_guard) = mlog.lock() { + let time = mlog_guard.elapsed_ms(); + let stream_id = 0; + let event = mlog::subgroup_header_created(time, stream_id, &header); + let _ = mlog_guard.add_event(event); + } + } + + // Helper to write an object + async fn write_object( + writer: &mut Writer, + object_reader: &mut serve::SubgroupObjectReader, + has_extension_headers: bool, + object_count: u64, + subgroup_reader: &serve::SubgroupReader, + state: &State, + mlog: &Option>>, + ) -> Result<(), SessionError> { + if has_extension_headers { + let subgroup_object = data::SubgroupObjectExt { + object_id_delta: 0, + extension_headers: object_reader.extension_headers.clone(), + payload_length: object_reader.size, + status: if object_reader.size == 0 { + Some(object_reader.status) + } else { + None + }, + }; + + log::debug!( + "[PUBLISHED] serve_subgroup: sending object #{} (ext) - object_id={}, payload_length={}, status={:?}", + object_count + 1, + object_reader.object_id, + subgroup_object.payload_length, + subgroup_object.status + ); + + writer.encode(&subgroup_object).await?; + + if let Some(ref mlog) = mlog { + if let Ok(mut mlog_guard) = mlog.lock() { + let time = mlog_guard.elapsed_ms(); + let stream_id = 0; + let event = mlog::subgroup_object_ext_created( + time, + stream_id, + subgroup_reader.group_id, + subgroup_reader.subgroup_id, + object_reader.object_id, + &subgroup_object, + ); + let _ = mlog_guard.add_event(event); + } + } + } else { + let subgroup_object = data::SubgroupObject { + object_id_delta: 0, + payload_length: object_reader.size, + status: if object_reader.size == 0 { + Some(object_reader.status) + } else { + None + }, + }; + + log::debug!( + "[PUBLISHED] serve_subgroup: sending object #{} - object_id={}, payload_length={}, status={:?}", + object_count + 1, + object_reader.object_id, + subgroup_object.payload_length, + subgroup_object.status + ); + + writer.encode(&subgroup_object).await?; + + // No mlog for non-ext objects currently + } + + state + .lock_mut() + .ok_or(ServeError::Done)? + .update_largest_location( + subgroup_reader.group_id, + object_reader.object_id, + )?; + + while let Some(chunk) = object_reader.read().await? { + writer.write(&chunk).await?; + } + + Ok(()) + } + + // Write the first object that we already read + let mut object_count = 0; + let mut first_object = first_object; + write_object( + &mut writer, + &mut first_object, + has_extension_headers, + object_count, + &subgroup_reader, + &state, + &mlog, + ) + .await?; + object_count += 1; + + // Continue with remaining objects + while let Some(mut subgroup_object_reader) = subgroup_reader.next().await? { + write_object( + &mut writer, + &mut subgroup_object_reader, + has_extension_headers, + object_count, + &subgroup_reader, + &state, + &mlog, + ) + .await?; + object_count += 1; + } + + log::info!( + "[PUBLISHED] serve_subgroup: completed subgroup (group_id={}, subgroup_id={:?}, {} objects sent, header_type={:?})", + subgroup_reader.group_id, + subgroup_reader.subgroup_id, + object_count, + header_type + ); + + Ok(()) + } + + async fn serve_datagrams( + &mut self, + mut datagrams: serve::DatagramsReader, + ) -> Result<(), SessionError> { + log::debug!("[PUBLISHED] serve_datagrams: starting"); + + let mut datagram_count = 0; + while let Some(datagram) = datagrams.read().await? { + let has_extension_headers = !datagram.extension_headers.is_empty(); + let datagram_type = if has_extension_headers { + data::DatagramType::ObjectIdPayloadExt + } else { + data::DatagramType::ObjectIdPayload + }; + + let encoded_datagram = data::Datagram { + datagram_type, + track_alias: self.info.track_alias, + group_id: datagram.group_id, + object_id: Some(datagram.object_id), + publisher_priority: Some(datagram.priority), + extension_headers: if has_extension_headers { + Some(datagram.extension_headers.clone()) + } else { + None + }, + status: None, + payload: Some(datagram.payload), + }; + + let payload_len = encoded_datagram + .payload + .as_ref() + .map(|p| p.len()) + .unwrap_or(0); + let mut buffer = bytes::BytesMut::with_capacity(payload_len + 100); + encoded_datagram.encode(&mut buffer)?; + + log::debug!( + "[PUBLISHED] serve_datagrams: sending datagram #{} - track_alias={}, group_id={}, object_id={}, priority={:?}, payload_len={}", + datagram_count + 1, + encoded_datagram.track_alias, + encoded_datagram.group_id, + encoded_datagram.object_id.unwrap(), + encoded_datagram.publisher_priority, + payload_len + ); + + if let Some(ref mlog) = self.mlog { + if let Ok(mut mlog_guard) = mlog.lock() { + let time = mlog_guard.elapsed_ms(); + let stream_id = 0; + let _ = mlog_guard.add_event(mlog::object_datagram_created( + time, + stream_id, + &encoded_datagram, + )); + } + } + + self.publisher.send_datagram(buffer.into()).await?; + + self.state + .lock_mut() + .ok_or(ServeError::Done)? + .update_largest_location( + encoded_datagram.group_id, + encoded_datagram.object_id.unwrap(), + )?; + + datagram_count += 1; + } + + log::info!( + "[PUBLISHED] serve_datagrams: completed ({} datagrams sent)", + datagram_count + ); + + Ok(()) + } +} + +impl ops::Deref for Published { + type Target = PublishInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +impl Drop for Published { + fn drop(&mut self) { + let state = self.state.lock(); + let err = state + .closed + .as_ref() + .err() + .cloned() + .unwrap_or(ServeError::Done); + drop(state); + + self.publisher.send_message(message::PublishDone { + id: self.info.id, + status_code: err.code(), + stream_count: 0, // TODO SLG + reason: ReasonPhrase(err.to_string()), + }); + } +} + +pub(super) struct PublishedRecv { + state: State, +} + +impl PublishedRecv { + pub fn recv_ok(&mut self, msg: &message::PublishOk) -> Result<(), ServeError> { + let state = self.state.lock(); + if state.ok { + return Err(ServeError::Duplicate); + } + + if let Some(mut state) = state.into_mut() { + state.ok = true; + + // Extract subscription properties from parameters (draft-16) + if let Some(v) = msg.params.get_intvalue(ParameterType::Forward.into()) { + state.forward = v == 1; + } + if let Some(v) = msg.params.get_intvalue(ParameterType::SubscriberPriority.into()) { + state.subscriber_priority = v as u8; + } + if let Some(v) = msg.params.get_intvalue(ParameterType::GroupOrder.into()) { + state.group_order = match v { + 0x0 => message::GroupOrder::Publisher, + 0x1 => message::GroupOrder::Ascending, + 0x2 => message::GroupOrder::Descending, + _ => message::GroupOrder::Ascending, + }; + } + } + + Ok(()) + } + + pub fn recv_error(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + + Ok(()) + } +} diff --git a/moq-transport/src/session/publisher.rs b/moq-transport/src/session/publisher.rs index 1d0c45b1..7303c979 100644 --- a/moq-transport/src/session/publisher.rs +++ b/moq-transport/src/session/publisher.rs @@ -1,55 +1,49 @@ use std::{ - collections::{hash_map, HashMap}, + collections::{hash_map, HashMap, HashSet}, sync::{atomic, Arc, Mutex}, }; -use futures::{stream::FuturesUnordered, StreamExt}; - use crate::{ - coding::TrackNamespace, - message::{self, Message}, + coding::{KeyValuePairs, TrackNamespace}, + message::{self, GroupOrder, Message, ParameterType}, mlog, - serve::{ServeError, TracksReader}, + serve::{self, ServeError, TracksReader}, }; use crate::watch::Queue; use super::{ - Announce, AnnounceRecv, Session, SessionError, Subscribed, SubscribedRecv, TrackStatusRequested, + PublishNamespace, PublishNamespaceRecv, Published, PublishedRecv, Session, SessionError, + SubscribeNamespaceReceived, SubscribeNamespaceReceivedRecv, Subscribed, SubscribedRecv, + TrackStatusRequested, }; -// TODO remove Clone. #[derive(Clone)] pub struct Publisher { webtransport: web_transport::Session, - /// When the announce method is used, a new entry is added to this HashMap to track outbound announcement - announces: Arc>>, + publish_namespaces: Arc>>, + + filtered_namespaces: Arc>>, - /// When a Subscribe is received and we have a previous announce for the namespace, then a new entry is - /// added to this HashMap to track the inbound subscription subscribeds: Arc>>, - /// When a Subscribe is received and we DO NOT have a previous announce for the namespace, then a new entry is - /// added to this Queue to track the inbound subscription unknown_subscribed: Queue, - /// When a TrackStatus is received and we DO NOT have a previous announce for the namespace, then a new entry is - /// added to this Queue to track the inbound track status request unknown_track_status_requested: Queue, - /// The queue we will write any outbound control messages we want to sent, the session run_send task - /// will process the queue and send the message on the control stream. + subscribe_namespaces_received: Arc>>, + + subscribe_namespace_received_queue: Queue, + + publisheds: Arc>>, + + next_track_alias: Arc, + outgoing: Queue, - /// When we need a new Request Id for sending a request, we can get it from here. Note: The instance - /// of AtomicU64 is shared with the Subscriber, so the session uses unique request ids for all requests - /// generated. Note: If we initiated the QUIC connection then request id's start at 0 and increment by 2 - /// for each request (even numbers). If we accepted an inbound QUIC connection then request id's start at 1 and - /// increment by 2 for each request (odd numbers). next_requestid: Arc, - /// Optional mlog writer for logging transport events mlog: Option>>, } @@ -62,16 +56,26 @@ impl Publisher { ) -> Self { Self { webtransport, - announces: Default::default(), + publish_namespaces: Default::default(), + filtered_namespaces: Default::default(), subscribeds: Default::default(), unknown_subscribed: Default::default(), unknown_track_status_requested: Default::default(), + subscribe_namespaces_received: Default::default(), + subscribe_namespace_received_queue: Default::default(), + publisheds: Default::default(), + next_track_alias: Arc::new(atomic::AtomicU64::new(0)), outgoing, next_requestid, mlog, } } + pub fn next_track_alias(&self) -> u64 { + self.next_track_alias + .fetch_add(1, atomic::Ordering::Relaxed) + } + pub async fn accept( session: web_transport::Session, ) -> Result<(Session, Publisher), SessionError> { @@ -86,81 +90,37 @@ impl Publisher { Ok((session, publisher)) } - /// Announce a namespace and serve tracks using the provided [serve::TracksReader]. - /// The caller uses [serve::TracksWriter] for static tracks and [serve::TracksRequest] for dynamic tracks. - pub async fn announce(&mut self, tracks: TracksReader) -> Result<(), SessionError> { - // Check if annouce for this namespace already exists or not, and if not, then create a new Announce - let announce = match self - .announces + pub async fn publish_namespace( + &mut self, + namespace: TrackNamespace, + ) -> Result { + if self + .filtered_namespaces .lock() .unwrap() - .entry(tracks.namespace.clone()) + .contains(&namespace) + { + return Err(ServeError::Cancel.into()); + } + + let publish_ns = match self + .publish_namespaces + .lock() + .unwrap() + .entry(namespace.clone()) { - // Namespace already exists in HashMap (has already been announced) - return Duplicate error hash_map::Entry::Occupied(_) => return Err(ServeError::Duplicate.into()), - // This is a new announce, send announce message to peer. hash_map::Entry::Vacant(entry) => { - // Get the current next request id to use and increment the value for by 2 for the next request let request_id = self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed); - let (send, recv) = - Announce::new(self.clone(), request_id, tracks.namespace.clone()); + let (send, recv) = PublishNamespace::new(self.clone(), request_id, namespace); entry.insert(recv); send } }; - let mut subscribe_tasks = FuturesUnordered::new(); - let mut status_tasks = FuturesUnordered::new(); - let mut subscribe_done = false; - let mut status_done = false; - - // The code enters an infinite loop and waits for one of several events: - // - A new subscription arrives. - // - A new track status request arrives. - // - One of the spawned subscription-handling tasks completes. - // - One of the spawned status-handling tasks completes. - // Exit the loop when all input streams are done (None), and all tasks have completed - loop { - tokio::select! { - // Get next subscription to this announce - res = announce.subscribed(), if !subscribe_done => { - match res? { - Some(subscribed) => { - let tracks = tracks.clone(); - - subscribe_tasks.push(async move { - let info = subscribed.info.clone(); - if let Err(err) = Self::serve_subscribe(subscribed, tracks).await { - log::warn!("failed serving subscribe: {:?}, error: {}", info, err) - } - }); - }, - None => subscribe_done = true, - } - - }, - res = announce.track_status_requested(), if !status_done => { - match res? { - Some(status) => { - let tracks = tracks.clone(); - - status_tasks.push(async move { - let request_msg = status.request_msg.clone(); - if let Err(err) = Self::serve_track_status(status, tracks).await { - log::warn!("failed serving track status request: {:?}, error: {}", request_msg, err) - } - }); - }, - None => status_done = true, - } - }, - Some(res) = subscribe_tasks.next() => res, - Some(res) = status_tasks.next() => res, - else => return Ok(()) - } - } + Ok(publish_ns) } pub async fn serve_subscribe( @@ -206,16 +166,124 @@ impl Publisher { Ok(()) } - // Returns subscriptions that do not map to an active announce. pub async fn subscribed(&mut self) -> Option { self.unknown_subscribed.pop().await } - // Returns track_status requests that do not map to an active announce. pub async fn track_status_requested(&mut self) -> Option { self.unknown_track_status_requested.pop().await } + pub async fn subscribe_namespace_received(&mut self) -> Option { + self.subscribe_namespace_received_queue.pop().await + } + + pub async fn publish(&mut self, track: serve::TrackReader) -> Result { + let request_id = self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed); + let track_alias = self + .next_track_alias + .fetch_add(1, atomic::Ordering::Relaxed); + + let mut params = KeyValuePairs::new(); + params.set_intvalue(ParameterType::GroupOrder.into(), GroupOrder::Ascending as u64); + params.set_intvalue(ParameterType::Forward.into(), 1); + if let Some(loc) = track.largest_location() { + let mut buf = bytes::BytesMut::new(); + use crate::coding::Encode; + loc.encode(&mut buf).ok(); + params.set_bytesvalue(ParameterType::LargestObject.into(), buf.to_vec()); + } + + let msg = message::Publish { + id: request_id, + track_namespace: track.namespace.clone(), + track_name: track.name.clone(), + track_alias, + params, + track_extensions: Default::default(), + }; + + let (send, recv) = Published::new(self.clone(), msg, self.mlog.clone()); + + self.publisheds.lock().unwrap().insert(request_id, recv); + + Ok(send) + } + + pub async fn publish_with_options( + &mut self, + track: serve::TrackReader, + group_order: GroupOrder, + forward: bool, + ) -> Result { + let request_id = self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed); + let track_alias = self + .next_track_alias + .fetch_add(1, atomic::Ordering::Relaxed); + + let mut params = KeyValuePairs::new(); + params.set_intvalue(ParameterType::GroupOrder.into(), group_order as u64); + params.set_intvalue(ParameterType::Forward.into(), if forward { 1 } else { 0 }); + if let Some(loc) = track.largest_location() { + let mut buf = bytes::BytesMut::new(); + use crate::coding::Encode; + loc.encode(&mut buf).ok(); + params.set_bytesvalue(ParameterType::LargestObject.into(), buf.to_vec()); + } + + let msg = message::Publish { + id: request_id, + track_namespace: track.namespace.clone(), + track_name: track.name.clone(), + track_alias, + params, + track_extensions: Default::default(), + }; + + let (send, recv) = Published::new(self.clone(), msg, self.mlog.clone()); + + self.publisheds.lock().unwrap().insert(request_id, recv); + + Ok(send) + } + + /// Publish a track with specific track extensions (for relay forwarding) + pub async fn publish_with_extensions( + &mut self, + track: serve::TrackReader, + track_extensions: crate::data::ExtensionHeaders, + ) -> Result { + let request_id = self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed); + let track_alias = self + .next_track_alias + .fetch_add(1, atomic::Ordering::Relaxed); + + let mut params = KeyValuePairs::new(); + params.set_intvalue(ParameterType::GroupOrder.into(), GroupOrder::Ascending as u64); + params.set_intvalue(ParameterType::Forward.into(), 1); + if let Some(loc) = track.largest_location() { + let mut buf = bytes::BytesMut::new(); + use crate::coding::Encode; + loc.encode(&mut buf).ok(); + params.set_bytesvalue(ParameterType::LargestObject.into(), buf.to_vec()); + } + + let msg = message::Publish { + id: request_id, + track_namespace: track.namespace.clone(), + track_name: track.name.clone(), + track_alias, + params, + track_extensions, + }; + + let (send, recv) = Published::new(self.clone(), msg, self.mlog.clone()); + + self.publisheds.lock().unwrap().insert(request_id, recv); + + Ok(send) + } + pub(crate) fn recv_message(&mut self, msg: message::Subscriber) -> Result<(), SessionError> { let res = match msg { message::Subscriber::Subscribe(msg) => self.recv_subscribe(msg), @@ -226,23 +294,13 @@ impl Publisher { Err(SessionError::unimplemented("FETCH_CANCEL")) } message::Subscriber::TrackStatus(msg) => self.recv_track_status(msg), - message::Subscriber::SubscribeNamespace(_msg) => { - Err(SessionError::unimplemented("SUBSCRIBE_NAMESPACE")) - } - message::Subscriber::UnsubscribeNamespace(_msg) => { - Err(SessionError::unimplemented("UNSUBSCRIBE_NAMESPACE")) - } + message::Subscriber::SubscribeNamespace(msg) => self.recv_subscribe_namespace(msg), message::Subscriber::PublishNamespaceCancel(msg) => { self.recv_publish_namespace_cancel(msg) } - message::Subscriber::PublishNamespaceOk(msg) => self.recv_publish_namespace_ok(msg), - message::Subscriber::PublishNamespaceError(msg) => { - self.recv_publish_namespace_error(msg) - } - message::Subscriber::PublishOk(_msg) => Err(SessionError::unimplemented("PUBLISH_OK")), - message::Subscriber::PublishError(_msg) => { - Err(SessionError::unimplemented("PUBLISH_ERROR")) - } + message::Subscriber::RequestOk(msg) => self.recv_request_ok(msg), + message::Subscriber::PublishOk(msg) => self.recv_publish_ok(msg), + message::Subscriber::RequestError(msg) => self.recv_request_error(msg), }; if let Err(err) = res { @@ -252,42 +310,29 @@ impl Publisher { Ok(()) } - fn recv_publish_namespace_ok( - &mut self, - msg: message::PublishNamespaceOk, - ) -> Result<(), SessionError> { - // We need to find the announce request using the request id, however the self.announces data structure - // is a HashMap indexed by Namespace (which is needed for handling PUBLISH_NAMESPACE_CANCEL). TODO - make more efficient. - // For now iterate through all self.annouces until we find the matching id. - let mut announces = self.announces.lock().unwrap(); - let announce = announces.iter_mut().find(|(_k, v)| v.request_id == msg.id); - - if let Some(announce) = announce { - announce.1.recv_ok()?; + fn recv_request_ok(&mut self, msg: message::RequestOk) -> Result<(), SessionError> { + let mut publish_namespaces = self.publish_namespaces.lock().unwrap(); + let entry = publish_namespaces + .iter_mut() + .find(|(_k, v)| v.request_id == msg.id); + + if let Some(entry) = entry { + entry.1.recv_ok()?; } Ok(()) } - fn recv_publish_namespace_error( - &mut self, - msg: message::PublishNamespaceError, - ) -> Result<(), SessionError> { - // We need to find the announce request using the request id, however the self.announces data structure - // is a HashMap indexed by Namespace (which is needed for handling PUBLISH_NAMESPACE_CANCEL). TODO - make more efficient. - // For now iterate through all self.annouces until we find the matching id. - let mut announces = self.announces.lock().unwrap(); + fn recv_request_error(&mut self, msg: message::RequestError) -> Result<(), SessionError> { + let mut publish_namespaces = self.publish_namespaces.lock().unwrap(); - // Find the key first (immutable borrow only) - let key_opt = announces + let key_opt = publish_namespaces .iter() .find(|(_k, v)| v.request_id == msg.id) .map(|(k, _)| k.clone()); - // Remove from HashMap and take ownership if let Some(key) = key_opt { - if let Some((_ns, v)) = announces.remove_entry(&key) { - // Step 3: call recv_error, consuming v + if let Some((_ns, v)) = publish_namespaces.remove_entry(&key) { v.recv_error(ServeError::Closed(msg.error_code))?; } } @@ -299,10 +344,21 @@ impl Publisher { &mut self, msg: message::PublishNamespaceCancel, ) -> Result<(), SessionError> { - // TODO: If a publisher receives new subscriptions for that namespace after receiving an ANNOUNCE_CANCEL, - // it SHOULD close the session as a 'Protocol Violation'. - if let Some(announce) = self.announces.lock().unwrap().remove(&msg.track_namespace) { - announce.recv_error(ServeError::Cancel)?; + if let Some(entry) = self + .publish_namespaces + .lock() + .unwrap() + .remove(&msg.track_namespace) + { + entry.recv_error(ServeError::Cancel)?; + } + + Ok(()) + } + + fn recv_publish_ok(&mut self, msg: message::PublishOk) -> Result<(), SessionError> { + if let Some(published) = self.publisheds.lock().unwrap().get_mut(&msg.id) { + published.recv_ok(&msg)?; } Ok(()) @@ -314,29 +370,18 @@ impl Publisher { let subscribed = { let mut subscribeds = self.subscribeds.lock().unwrap(); - // See if entry exists for this request id already, if so error out let entry = match subscribeds.entry(msg.id) { hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), hash_map::Entry::Vacant(entry) => entry, }; - // Create new Subscribed entry and add to HashMap let (send, recv) = Subscribed::new(self.clone(), msg, self.mlog.clone()); entry.insert(recv); send }; - // If we have an announce, route the subscribe to it. - if let Some(announce) = self.announces.lock().unwrap().get_mut(&namespace) { - return announce.recv_subscribe(subscribed).map_err(Into::into); - } - - // Otherwise, put it in the unknown queue. - // TODO Have some way to detect if the application is not reading from the unknown queue, - // then send SubscribeError. if let Err(err) = self.unknown_subscribed.push(subscribed) { - // Default to closing with a not found error I guess. err.close(ServeError::not_found_ctx(format!( "unknown_subscribed queue full for namespace {:?}", namespace @@ -355,26 +400,12 @@ impl Publisher { } fn recv_track_status(&mut self, msg: message::TrackStatus) -> Result<(), SessionError> { - let namespace = msg.track_namespace.clone(); - - // Create TrackStatusRequested to track this request let track_status_requested = TrackStatusRequested::new(self.clone(), msg); - // If we have an announce, route the track_status to it. - if let Some(announce) = self.announces.lock().unwrap().get_mut(&namespace) { - return announce - .recv_track_status_requested(track_status_requested) - .map_err(Into::into); - } - - // Otherwise, put it in the unknown_track_status queue. - // TODO Have some way to detect if the application is not reading from the unknown_track_status queue, - // then send TrackStatusError. if let Err(mut err) = self .unknown_track_status_requested .push(track_status_requested) { - // push only fails if the queue is dropped, send TrackStatusError, Internal error err.respond_error(0, "Internal error")?; } @@ -389,15 +420,48 @@ impl Publisher { Ok(()) } - /// Process a message before sending it, performing any necessary internal actions. + fn recv_subscribe_namespace( + &mut self, + msg: message::SubscribeNamespace, + ) -> Result<(), SessionError> { + let namespace_prefix = msg.track_namespace_prefix.clone(); + + self.filtered_namespaces + .lock() + .unwrap() + .remove(&namespace_prefix); + + let mut entries = self.subscribe_namespaces_received.lock().unwrap(); + + let entry = match entries.entry(msg.id) { + hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), + hash_map::Entry::Vacant(entry) => entry, + }; + + let (send, recv) = + SubscribeNamespaceReceived::new(self.clone(), msg.id, namespace_prefix); + + if let Err(send) = self.subscribe_namespace_received_queue.push(send) { + send.reject(0x0, "Internal error")?; + return Ok(()); + } + + entry.insert(recv); + + Ok(()) + } + fn act_on_message_to_send>( &mut self, msg: T, ) -> message::Publisher { let msg = msg.into(); match &msg { - message::Publisher::PublishDone(m) => self.drop_subscribe(m.id), - message::Publisher::SubscribeError(m) => self.drop_subscribe(m.id), + message::Publisher::PublishDone(m) => { + self.drop_subscribe(m.id); + self.drop_published(m.id); + } + message::Publisher::RequestError(m) => self.drop_subscribe(m.id), message::Publisher::PublishNamespaceDone(m) => { self.drop_publish_namespace(&m.track_namespace); } @@ -429,7 +493,11 @@ impl Publisher { } fn drop_publish_namespace(&mut self, namespace: &TrackNamespace) { - self.announces.lock().unwrap().remove(namespace); + self.publish_namespaces.lock().unwrap().remove(namespace); + } + + fn drop_published(&mut self, id: u64) { + self.publisheds.lock().unwrap().remove(&id); } pub(super) async fn open_uni(&mut self) -> Result { @@ -439,4 +507,16 @@ impl Publisher { pub(super) async fn send_datagram(&mut self, data: bytes::Bytes) -> Result<(), SessionError> { Ok(self.webtransport.send_datagram(data).await?) } + + /// Forward a PUBLISH message to the subscriber (used by relay for SUBSCRIBE_NAMESPACE flow). + /// This sends the message without tracking it for PUBLISH_OK response handling. + pub fn forward_publish(&mut self, msg: message::Publish) { + self.outgoing.push(msg.into()).ok(); + } + + /// Forward a NAMESPACE message to the subscriber (used by relay for SUBSCRIBE_NAMESPACE flow). + /// This announces a namespace that matches the subscriber's SUBSCRIBE_NAMESPACE prefix. + pub fn forward_namespace(&mut self, msg: message::Namespace) { + self.outgoing.push(msg.into()).ok(); + } } diff --git a/moq-transport/src/session/reader.rs b/moq-transport/src/session/reader.rs index 18a6ae16..1dd05530 100644 --- a/moq-transport/src/session/reader.rs +++ b/moq-transport/src/session/reader.rs @@ -68,7 +68,7 @@ impl Reader { // We always read at least once to avoid an infinite loop if some dingus puts remain=0 loop { let before_read = self.buffer.len(); - if !self.stream.read_buf(&mut self.buffer).await? { + if self.stream.read_buf(&mut self.buffer).await?.is_none() { log::warn!( "[READER] decode: stream ended while waiting for data (have={} bytes, need={})", self.buffer.len(), @@ -113,7 +113,7 @@ impl Reader { return Ok(Some(data)); } - let chunk = self.stream.read_chunk(max).await?; + let chunk = self.stream.read(max).await?; if let Some(ref data) = chunk { log::trace!("[READER] read_chunk: read {} bytes from stream", data.len()); } else { @@ -127,6 +127,6 @@ impl Reader { return Ok(false); } - Ok(!self.stream.read_buf(&mut self.buffer).await?) + Ok(self.stream.read_buf(&mut self.buffer).await?.is_none()) } } diff --git a/moq-transport/src/session/subscribe.rs b/moq-transport/src/session/subscribe.rs index b536641b..f9bc1f47 100644 --- a/moq-transport/src/session/subscribe.rs +++ b/moq-transport/src/session/subscribe.rs @@ -1,9 +1,8 @@ use std::ops; use crate::{ - coding::{KeyValuePairs, Location, TrackNamespace}, - data, - message::{self, FilterType, GroupOrder}, + coding::{KeyValuePairs, TrackNamespace}, + data, message, serve::{self, ServeError, TrackWriter, TrackWriterMode}, }; @@ -17,22 +16,6 @@ pub struct SubscribeInfo { pub id: u64, pub track_namespace: TrackNamespace, pub track_name: String, - - /// Subscriber Priority - pub subscriber_priority: u8, - pub group_order: GroupOrder, - - /// Forward Flag - pub forward: bool, - - /// Filter type - pub filter_type: FilterType, - - /// The starting location for this subscription. Only present for "AbsoluteStart" and "AbsoluteRange" filter types. - pub start_location: Option, - /// End group id, inclusive, for the subscription, if applicable. Only present for "AbsoluteRange" filter type. - pub end_group_id: Option, - /// Optional parameters pub params: KeyValuePairs, @@ -46,12 +29,6 @@ impl SubscribeInfo { id: msg.id, track_namespace: msg.track_namespace.clone(), track_name: msg.track_name.clone(), - subscriber_priority: msg.subscriber_priority, - group_order: msg.group_order, - forward: msg.forward, - filter_type: msg.filter_type, - start_location: msg.start_location, - end_group_id: msg.end_group_id, params: msg.params.clone(), track_status: false, } @@ -93,13 +70,6 @@ impl Subscribe { id: request_id, track_namespace: track.namespace.clone(), track_name: track.name.clone(), - // TODO add prioritization logic on the publisher side - subscriber_priority: 127, // default to mid value, see: https://github.com/moq-wg/moq-transport/issues/504 - group_order: GroupOrder::Publisher, // defer to publisher send order - forward: true, // default to forwarding objects - filter_type: FilterType::LargestObject, - start_location: None, - end_group_id: None, params: Default::default(), }; let info = SubscribeInfo::new_from_subscribe(&subscribe_message); @@ -205,16 +175,20 @@ impl SubscribeRecv { _ => return Err(ServeError::Mode), }; - let writer = subgroups.create(serve::Subgroup { + let result = subgroups.create(serve::Subgroup { group_id: header.group_id, // When subgroup_id is not present in the header type, it implicitly means subgroup 0 subgroup_id: header.subgroup_id.unwrap_or(0), - priority: header.publisher_priority, - })?; + // When priority is not present (NoPriority header types), default to 0 + priority: header.publisher_priority.unwrap_or(0), + // Preserve the incoming header type for forwarding + header_type: Some(header.header_type), + }); + // Always put writer back, even on error, to avoid losing it self.writer = Some(subgroups.into()); - Ok(writer) + result } pub fn datagram(&mut self, datagram: data::Datagram) -> Result<(), ServeError> { @@ -224,23 +198,39 @@ impl SubscribeRecv { TrackWriterMode::Track(track) => { // convert Track -> Datagrams writer, write, then put Datagrams back let mut datagrams = track.datagrams()?; + // Determine status from datagram type or explicit status field + let status = if datagram.datagram_type.is_end_of_group() { + Some(crate::data::ObjectStatus::EndOfGroup) + } else { + datagram.status + }; datagrams.write(serve::Datagram { group_id: datagram.group_id, object_id: datagram.object_id.unwrap_or(0), - priority: datagram.publisher_priority, + // When priority is not present (NoPriority datagram types), default to 0 + priority: datagram.publisher_priority.unwrap_or(0), payload: datagram.payload.unwrap_or_default(), extension_headers: datagram.extension_headers.unwrap_or_default(), + status, })?; self.writer = Some(TrackWriterMode::Datagrams(datagrams)); Ok(()) } TrackWriterMode::Datagrams(mut datagrams) => { + // Determine status from datagram type or explicit status field + let status = if datagram.datagram_type.is_end_of_group() { + Some(crate::data::ObjectStatus::EndOfGroup) + } else { + datagram.status + }; datagrams.write(serve::Datagram { group_id: datagram.group_id, object_id: datagram.object_id.unwrap_or(0), - priority: datagram.publisher_priority, + // When priority is not present (NoPriority datagram types), default to 0 + priority: datagram.publisher_priority.unwrap_or(0), payload: datagram.payload.unwrap_or_default(), extension_headers: datagram.extension_headers.unwrap_or_default(), + status, })?; self.writer = Some(TrackWriterMode::Datagrams(datagrams)); Ok(()) diff --git a/moq-transport/src/session/subscribe_namespace.rs b/moq-transport/src/session/subscribe_namespace.rs new file mode 100644 index 00000000..a89d7ecc --- /dev/null +++ b/moq-transport/src/session/subscribe_namespace.rs @@ -0,0 +1,143 @@ +use std::ops; + +use crate::coding::TrackNamespace; +use crate::watch::State; +use crate::{message, serve::ServeError}; + +use super::Subscriber; + +#[derive(Debug, Clone)] +pub struct SubscribeNsInfo { + pub request_id: u64, + pub namespace_prefix: TrackNamespace, +} + +struct SubscribeNsState { + ok: bool, + closed: Result<(), ServeError>, +} + +impl Default for SubscribeNsState { + fn default() -> Self { + Self { + ok: false, + closed: Ok(()), + } + } +} + +/// Represents an outbound SUBSCRIBE_NAMESPACE request (subscriber side). +/// When dropped, sends UNSUBSCRIBE_NAMESPACE to the peer. +#[must_use = "sends UNSUBSCRIBE_NAMESPACE on drop"] +pub struct SubscribeNs { + subscriber: Subscriber, + state: State, + + pub info: SubscribeNsInfo, +} + +impl SubscribeNs { + pub(super) fn new( + mut subscriber: Subscriber, + request_id: u64, + namespace_prefix: TrackNamespace, + ) -> (SubscribeNs, SubscribeNsRecv) { + let info = SubscribeNsInfo { + request_id, + namespace_prefix: namespace_prefix.clone(), + }; + + subscriber.send_message(message::SubscribeNamespace::new( + request_id, + namespace_prefix, + 1, + )); + + let (send, recv) = State::default().split(); + + let send = Self { + subscriber, + info, + state: send, + }; + let recv = SubscribeNsRecv { state: recv }; + + (send, recv) + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } + + pub async fn ok(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + if state.ok { + return Ok(()); + } + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } +} + +impl Drop for SubscribeNs { + fn drop(&mut self) { + // In draft-16, SUBSCRIBE_NAMESPACE uses its own bidirectional stream. + // Closing the stream implicitly unsubscribes. + } +} + +impl ops::Deref for SubscribeNs { + type Target = SubscribeNsInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +pub(super) struct SubscribeNsRecv { + state: State, +} + +impl SubscribeNsRecv { + pub fn recv_ok(&mut self) -> Result<(), ServeError> { + if let Some(mut state) = self.state.lock_mut() { + if state.ok { + return Err(ServeError::Duplicate); + } + + state.ok = true; + } + + Ok(()) + } + + pub fn recv_error(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + + Ok(()) + } +} diff --git a/moq-transport/src/session/subscribe_namespace_received.rs b/moq-transport/src/session/subscribe_namespace_received.rs new file mode 100644 index 00000000..60b3a333 --- /dev/null +++ b/moq-transport/src/session/subscribe_namespace_received.rs @@ -0,0 +1,148 @@ +use std::ops; + +use crate::coding::{ReasonPhrase, TrackNamespace}; +use crate::watch::State; +use crate::{message, serve::ServeError}; + +use super::Publisher; + +#[derive(Debug, Clone)] +pub struct SubscribeNamespaceReceivedInfo { + pub request_id: u64, + pub namespace_prefix: TrackNamespace, +} + +struct SubscribeNamespaceReceivedState { + closed: Result<(), ServeError>, +} + +impl Default for SubscribeNamespaceReceivedState { + fn default() -> Self { + Self { closed: Ok(()) } + } +} + +#[must_use = "sends SUBSCRIBE_NAMESPACE_ERROR on drop if not accepted"] +pub struct SubscribeNamespaceReceived { + publisher: Publisher, + state: State, + pub info: SubscribeNamespaceReceivedInfo, + ok: bool, +} + +impl SubscribeNamespaceReceived { + pub(super) fn new( + publisher: Publisher, + request_id: u64, + namespace_prefix: TrackNamespace, + ) -> (Self, SubscribeNamespaceReceivedRecv) { + let info = SubscribeNamespaceReceivedInfo { + request_id, + namespace_prefix: namespace_prefix.clone(), + }; + + let (send, recv) = State::default().split(); + + let send = Self { + publisher, + info, + state: send, + ok: false, + }; + + let recv = SubscribeNamespaceReceivedRecv { + state: recv, + namespace_prefix, + }; + + (send, recv) + } + + pub fn ok(&mut self) -> Result<(), ServeError> { + if self.ok { + return Err(ServeError::Duplicate); + } + + self.publisher.send_message(message::RequestOk { + id: self.info.request_id, + params: Default::default(), + }); + + self.ok = true; + + Ok(()) + } + + pub fn reject(mut self, error_code: u64, reason: &str) -> Result<(), ServeError> { + self.publisher.send_message(message::RequestError { + id: self.info.request_id, + error_code, + retry_interval: 0, + reason_phrase: ReasonPhrase(reason.to_string()), + }); + + self.ok = true; + + Ok(()) + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } +} + +impl ops::Deref for SubscribeNamespaceReceived { + type Target = SubscribeNamespaceReceivedInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +impl Drop for SubscribeNamespaceReceived { + fn drop(&mut self) { + if self.ok { + return; + } + + self.publisher.send_message(message::RequestError { + id: self.info.request_id, + error_code: ServeError::NotFound.code(), + retry_interval: 0, + reason_phrase: ReasonPhrase("SUBSCRIBE_NAMESPACE not handled".to_string()), + }); + } +} + +pub(super) struct SubscribeNamespaceReceivedRecv { + state: State, + namespace_prefix: TrackNamespace, +} + +impl SubscribeNamespaceReceivedRecv { + pub fn namespace_prefix(&self) -> &TrackNamespace { + &self.namespace_prefix + } + + pub fn recv_unsubscribe(&mut self) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + if let Some(mut state) = state.into_mut() { + state.closed = Err(ServeError::Cancel); + } + + Ok(()) + } +} diff --git a/moq-transport/src/session/subscribed.rs b/moq-transport/src/session/subscribed.rs index e87961fc..fc17c129 100644 --- a/moq-transport/src/session/subscribed.rs +++ b/moq-transport/src/session/subscribed.rs @@ -20,7 +20,20 @@ struct SubscribedState { closed: Result<(), ServeError>, } +impl Default for SubscribedState { + fn default() -> Self { + Self { + largest_location: None, + closed: Ok(()), + } + } +} + impl SubscribedState { + fn is_closed(&self) -> bool { + self.closed.is_err() + } + fn update_largest_location(&mut self, group_id: u64, object_id: u64) -> Result<(), ServeError> { if let Some(current_largest_location) = self.largest_location { let update_largest_location = Location::new(group_id, object_id); @@ -33,15 +46,6 @@ impl SubscribedState { } } -impl Default for SubscribedState { - fn default() -> Self { - Self { - largest_location: None, - closed: Ok(()), - } - } -} - pub struct Subscribed { /// The sessions Publisher manager, used to send control messages, /// create new QUIC streams, and send datagrams @@ -66,7 +70,7 @@ impl Subscribed { msg: message::Subscribe, mlog: Option>>, ) -> (Self, SubscribedRecv) { - let (send, recv) = State::default().split(); + let (send, recv) = State::new(SubscribedState::default()).split(); let info = SubscribeInfo::new_from_subscribe(&msg); let send = Self { publisher, @@ -102,14 +106,12 @@ impl Subscribed { // Send SubscribeOk using send_message_and_wait to ensure it is sent at least to the QUIC stack before // we start serving the track. If a subscriber gets the stream before SubscribeOk // then they won't recognize the track_alias in the stream header. + let track_alias = self.publisher.next_track_alias(); self.publisher .send_message_and_wait(message::SubscribeOk { id: self.info.id, - track_alias: self.info.id, // use subscription id as track alias - expires: 0, // TODO SLG - group_order: message::GroupOrder::Descending, // TODO: resolve correct value from publisher / subscriber prefs - content_exists: largest_location.is_some(), - largest_location, + track_alias, + track_extensions: Default::default(), params: Default::default(), }) .await; @@ -120,8 +122,12 @@ impl Subscribed { match track.mode().await? { // TODO cancel track/datagrams on closed TrackReaderMode::Stream(_stream) => panic!("deprecated"), - TrackReaderMode::Subgroups(subgroups) => self.serve_subgroups(subgroups).await, - TrackReaderMode::Datagrams(datagrams) => self.serve_datagrams(datagrams).await, + TrackReaderMode::Subgroups(subgroups) => { + self.serve_subgroups(subgroups, track_alias).await + } + TrackReaderMode::Datagrams(datagrams) => { + self.serve_datagrams(datagrams, track_alias).await + } } } @@ -178,9 +184,10 @@ impl Drop for Subscribed { reason: ReasonPhrase(err.to_string()), }); } else { - self.publisher.send_message(message::SubscribeError { + self.publisher.send_message(message::RequestError { id: self.info.id, error_code: err.code(), + retry_interval: 0, reason_phrase: ReasonPhrase(err.to_string()), }); }; @@ -191,6 +198,7 @@ impl Subscribed { async fn serve_subgroups( &mut self, mut subgroups: serve::SubgroupsReader, + track_alias: u64, ) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); let mut done: Option> = None; @@ -199,12 +207,19 @@ impl Subscribed { tokio::select! { res = subgroups.next(), if done.is_none() => match res { Ok(Some(subgroup)) => { + // Use preserved header type if available, otherwise default to SubgroupIdExt + let header_type = subgroup.info.header_type.unwrap_or(data::StreamHeaderType::SubgroupIdExt); + let subgroup_id = if header_type.has_subgroup_id() { + Some(subgroup.subgroup_id) + } else { + None + }; let header = data::SubgroupHeader { - header_type: data::StreamHeaderType::SubgroupIdExt, // SubGroupId = Yes, Extensions = Yes, ContainsEndOfGroup = No - track_alias: self.info.id, // use subscription id as track_alias + header_type, + track_alias, group_id: subgroup.group_id, - subgroup_id: Some(subgroup.subgroup_id), - publisher_priority: subgroup.priority, + subgroup_id, + publisher_priority: Some(subgroup.priority), }; let publisher = self.publisher.clone(); @@ -251,7 +266,7 @@ impl Subscribed { let mut writer = Writer::new(send_stream); log::debug!( - "[PUBLISHER] serve_subgroup: sending header - track_alias={}, group_id={}, subgroup_id={:?}, priority={}, header_type={:?}", + "[PUBLISHER] serve_subgroup: sending header - track_alias={}, group_id={}, subgroup_id={:?}, priority={:?}, header_type={:?}", header.track_alias, header.group_id, header.subgroup_id, @@ -271,47 +286,78 @@ impl Subscribed { } } + let has_extension_headers = header.header_type.has_extension_headers(); let mut object_count = 0; while let Some(mut subgroup_object_reader) = subgroup_reader.next().await? { - let subgroup_object = data::SubgroupObjectExt { - object_id_delta: 0, // before delta logic, used to be subgroup_object_reader.object_id, - extension_headers: subgroup_object_reader.extension_headers.clone(), // Pass through extension headers - payload_length: subgroup_object_reader.size, - status: if subgroup_object_reader.size == 0 { - // Only set status if payload length is zero - Some(subgroup_object_reader.status) - } else { - None - }, - }; + if state.lock().is_closed() { + log::debug!( + "[PUBLISHER] serve_subgroup: subscription cancelled, stopping (group_id={}, subgroup_id={:?}, {} objects sent)", + subgroup_reader.group_id, + subgroup_reader.subgroup_id, + object_count + ); + return Ok(()); + } - log::debug!( - "[PUBLISHER] serve_subgroup: sending object #{} - object_id={}, object_id_delta={}, payload_length={}, status={:?}, extension_headers={:?}", - object_count + 1, - subgroup_object_reader.object_id, - subgroup_object.object_id_delta, - subgroup_object.payload_length, - subgroup_object.status, - subgroup_object.extension_headers - ); + // Encode object based on header type - must match what receiver expects + if has_extension_headers { + let subgroup_object = data::SubgroupObjectExt { + object_id_delta: 0, + extension_headers: subgroup_object_reader.extension_headers.clone(), + payload_length: subgroup_object_reader.size, + status: if subgroup_object_reader.size == 0 { + Some(subgroup_object_reader.status) + } else { + None + }, + }; - writer.encode(&subgroup_object).await?; + log::debug!( + "[PUBLISHER] serve_subgroup: sending object #{} (ext) - object_id={}, payload_length={}, status={:?}, extension_headers={:?}", + object_count + 1, + subgroup_object_reader.object_id, + subgroup_object.payload_length, + subgroup_object.status, + subgroup_object.extension_headers + ); - // Log subgroup object created/sent - if let Some(ref mlog) = mlog { - if let Ok(mut mlog_guard) = mlog.lock() { - let time = mlog_guard.elapsed_ms(); - let stream_id = 0; // TODO: Placeholder, need actual QUIC stream ID - let event = mlog::subgroup_object_ext_created( - time, - stream_id, - subgroup_reader.group_id, - subgroup_reader.subgroup_id, - subgroup_object_reader.object_id, - &subgroup_object, - ); - let _ = mlog_guard.add_event(event); + writer.encode(&subgroup_object).await?; + + if let Some(ref mlog) = mlog { + if let Ok(mut mlog_guard) = mlog.lock() { + let time = mlog_guard.elapsed_ms(); + let stream_id = 0; + let event = mlog::subgroup_object_ext_created( + time, + stream_id, + subgroup_reader.group_id, + subgroup_reader.subgroup_id, + subgroup_object_reader.object_id, + &subgroup_object, + ); + let _ = mlog_guard.add_event(event); + } } + } else { + let subgroup_object = data::SubgroupObject { + object_id_delta: 0, + payload_length: subgroup_object_reader.size, + status: if subgroup_object_reader.size == 0 { + Some(subgroup_object_reader.status) + } else { + None + }, + }; + + log::debug!( + "[PUBLISHER] serve_subgroup: sending object #{} - object_id={}, payload_length={}, status={:?}", + object_count + 1, + subgroup_object_reader.object_id, + subgroup_object.payload_length, + subgroup_object.status + ); + + writer.encode(&subgroup_object).await?; } state @@ -325,6 +371,13 @@ impl Subscribed { let mut chunks_sent = 0; let mut bytes_sent = 0; while let Some(chunk) = subgroup_object_reader.read().await? { + if state.lock().is_closed() { + log::debug!( + "[PUBLISHER] serve_subgroup: subscription cancelled during payload transfer" + ); + return Ok(()); + } + log::trace!( "[PUBLISHER] serve_subgroup: sending payload chunk #{} for object #{} ({} bytes)", chunks_sent + 1, @@ -358,12 +411,20 @@ impl Subscribed { async fn serve_datagrams( &mut self, mut datagrams: serve::DatagramsReader, + track_alias: u64, ) -> Result<(), SessionError> { log::debug!("[PUBLISHER] serve_datagrams: starting"); let mut datagram_count = 0; while let Some(datagram) = datagrams.read().await? { - // Determine datagram type based on extension headers presence + if self.state.lock().is_closed() { + log::debug!( + "[PUBLISHER] serve_datagrams: subscription cancelled, stopping ({} datagrams sent)", + datagram_count + ); + return Ok(()); + } + let has_extension_headers = !datagram.extension_headers.is_empty(); let datagram_type = if has_extension_headers { data::DatagramType::ObjectIdPayloadExt @@ -373,10 +434,10 @@ impl Subscribed { let encoded_datagram = data::Datagram { datagram_type, - track_alias: self.info.id, // use subscription id as track_alias + track_alias, group_id: datagram.group_id, object_id: Some(datagram.object_id), - publisher_priority: datagram.priority, + publisher_priority: Some(datagram.priority), extension_headers: if has_extension_headers { Some(datagram.extension_headers.clone()) } else { @@ -395,7 +456,7 @@ impl Subscribed { encoded_datagram.encode(&mut buffer)?; log::debug!( - "[PUBLISHER] serve_datagrams: sending datagram #{} - track_alias={}, group_id={}, object_id={}, priority={}, payload_len={}, extension_headers={:?}, total_encoded_len={}", + "[PUBLISHER] serve_datagrams: sending datagram #{} - track_alias={}, group_id={}, object_id={}, priority={:?}, payload_len={}, extension_headers={:?}, total_encoded_len={}", datagram_count + 1, encoded_datagram.track_alias, encoded_datagram.group_id, diff --git a/moq-transport/src/session/subscriber.rs b/moq-transport/src/session/subscriber.rs index 49c59e8d..7dd4b77c 100644 --- a/moq-transport/src/session/subscriber.rs +++ b/moq-transport/src/session/subscriber.rs @@ -17,41 +17,40 @@ use crate::{ use crate::watch::Queue; -use super::{Announced, AnnouncedRecv, Reader, Session, SessionError, Subscribe, SubscribeRecv}; +use super::{ + PublishNamespaceReceived, PublishNamespaceReceivedRecv, PublishReceived, PublishReceivedRecv, + Reader, Session, SessionError, Subscribe, SubscribeNs, SubscribeNsRecv, SubscribeRecv, +}; // Default timeout for waiting for subscribe aliases to become available via SUBSCRIBE_OK (1 second) const DEFAULT_ALIAS_WAIT_TIME_MS: u64 = 1000; -// TODO remove Clone. #[derive(Clone)] pub struct Subscriber { - /// The currently active inbound announces, keyed by namespace. - announced: Arc>>, + publish_namespaces_received: Arc>>, + + publish_namespace_received_queue: Queue, - /// Queue of announced namespaces we have received from the Publisher, waiting to be processed. - announced_queue: Queue, + subscribe_namespaces: Arc>>, - /// The currently active outbound subscribes, keyed by request id. subscribes: Arc>>, - /// Map of track alias to subscription id for quick lookup when receiving streams/datagrams. subscribe_alias_map: Arc>>, - /// Notify when subscribe alias map is updated subscribe_alias_notify: Arc, - /// The queue we will write any outbound control messages we want to send, the session run_send task - /// will process the queue and send the message on the control stream. + publishes_received: Arc>>, + + publish_received_queue: Queue, + + publish_alias_map: Arc>>, + + publish_alias_notify: Arc, + outgoing: Queue, - /// When we need a new Request Id for sending a request, we can get it from here. Note: The instance - /// of AtomicU64 is shared with the Subscriber, so the session uses unique request ids for all requests - /// generated. Note: If we initiated the QUIC connection then request id's start at 0 and increment by 2 - /// for each request (even numbers). If we accepted an inbound QUIC connection then request id's start at 1 and - /// increment by 2 for each request (odd numbers). next_requestid: Arc, - /// Optional mlog writer for logging transport events mlog: Option>>, } @@ -62,14 +61,19 @@ impl Subscriber { mlog: Option>>, ) -> Self { Self { - announced: Default::default(), - announced_queue: Default::default(), + publish_namespaces_received: Default::default(), + publish_namespace_received_queue: Default::default(), + subscribe_namespaces: Default::default(), subscribes: Default::default(), subscribe_alias_map: Default::default(), + subscribe_alias_notify: Arc::new(Notify::new()), + publishes_received: Default::default(), + publish_received_queue: Default::default(), + publish_alias_map: Default::default(), + publish_alias_notify: Arc::new(Notify::new()), outgoing, next_requestid, mlog, - subscribe_alias_notify: Arc::new(Notify::new()), } } @@ -85,13 +89,16 @@ impl Subscriber { Ok((session, subscriber)) } - /// Wait for the next announced namespace from the publisher, if any. - pub async fn announced(&mut self) -> Option { - self.announced_queue.pop().await + pub async fn publish_ns_recvd(&mut self) -> Option { + self.publish_namespace_received_queue.pop().await + } + + pub async fn publish_received(&mut self) -> Option { + self.publish_received_queue.pop().await } /// Get the current next request id to use and increment the value for by 2 for the next request - fn get_next_request_id(&self) -> u64 { + pub fn get_next_request_id(&self) -> u64 { self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed) } @@ -120,45 +127,49 @@ impl Subscriber { send.closed().await } + pub fn subscribe_ns( + &mut self, + namespace_prefix: TrackNamespace, + ) -> Result { + let request_id = self.get_next_request_id(); + + let mut subscribe_namespaces = self.subscribe_namespaces.lock().unwrap(); + let entry = match subscribe_namespaces.entry(request_id) { + hash_map::Entry::Occupied(_) => return Err(ServeError::Duplicate), + hash_map::Entry::Vacant(entry) => entry, + }; + + let (send, recv) = SubscribeNs::new(self.clone(), request_id, namespace_prefix); + entry.insert(recv); + + Ok(send) + } + /// Send a message to the publisher via the control stream. - pub(super) fn send_message>(&mut self, msg: M) { + pub fn send_message>(&mut self, msg: M) { let msg = msg.into(); // Remove our entry on terminal state. - match &msg { - message::Subscriber::PublishNamespaceCancel(msg) => { - self.drop_publish_namespace(&msg.track_namespace) - } - // TODO SLG - there is no longer a namespace in the error, need to map via request id - message::Subscriber::PublishNamespaceError(_msg) => {} // Not implemented yet - need request id mapping - _ => {} + if let message::Subscriber::PublishNamespaceCancel(msg) = &msg { + self.drop_publish_namespace(&msg.track_namespace) } // TODO report dropped messages? let _ = self.outgoing.push(msg.into()); } - /// Receive a message from the publisher via the control stream. pub(super) fn recv_message(&mut self, msg: message::Publisher) -> Result<(), SessionError> { let res = match &msg { message::Publisher::PublishNamespace(msg) => self.recv_publish_namespace(msg), - message::Publisher::PublishNamespaceDone(msg) => self.recv_publish_namespace_done(msg), - message::Publisher::Publish(_msg) => Err(SessionError::unimplemented("PUBLISH")), + message::Publisher::PublishNamespaceDone(msg) => self.recv_publish_ns_done(msg), + message::Publisher::Namespace(msg) => self.recv_namespace(msg), + message::Publisher::Publish(msg) => self.recv_publish(msg), message::Publisher::PublishDone(msg) => self.recv_publish_done(msg), message::Publisher::SubscribeOk(msg) => self.recv_subscribe_ok(msg), - message::Publisher::SubscribeError(msg) => self.recv_subscribe_error(msg), + message::Publisher::RequestError(msg) => self.recv_request_error(msg), message::Publisher::TrackStatusOk(msg) => self.recv_track_status_ok(msg), - message::Publisher::TrackStatusError(_msg) => { - Err(SessionError::unimplemented("TRACK_STATUS_ERROR")) - } message::Publisher::FetchOk(_msg) => Err(SessionError::unimplemented("FETCH_OK")), - message::Publisher::FetchError(_msg) => Err(SessionError::unimplemented("FETCH_ERROR")), - message::Publisher::SubscribeNamespaceOk(_msg) => { - Err(SessionError::unimplemented("SUBSCRIBE_NAMESPACE_OK")) - } - message::Publisher::SubscribeNamespaceError(_msg) => { - Err(SessionError::unimplemented("SUBSCRIBE_NAMESPACE_ERROR")) - } + message::Publisher::RequestOk(msg) => self.recv_request_ok(msg), }; if let Err(SessionError::Serve(err)) = res { @@ -169,23 +180,24 @@ impl Subscriber { res } - /// Handle the reception of a PublishNamespace message from the publisher. fn recv_publish_namespace( &mut self, msg: &message::PublishNamespace, ) -> Result<(), SessionError> { - let mut announces = self.announced.lock().unwrap(); + let mut entries = self.publish_namespaces_received.lock().unwrap(); - // Check for duplicate namespace announcement - let entry = match announces.entry(msg.track_namespace.clone()) { + let entry = match entries.entry(msg.track_namespace.clone()) { hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), hash_map::Entry::Vacant(entry) => entry, }; - // Create the announced namespace and insert it into our map of active announces, and the announced queue. - let (announced, recv) = Announced::new(self.clone(), msg.id, msg.track_namespace.clone()); - if let Err(announced) = self.announced_queue.push(announced) { - announced.close(ServeError::Cancel)?; + let (publish_ns_received, recv) = + PublishNamespaceReceived::new(self.clone(), msg.id, msg.track_namespace.clone()); + if let Err(publish_ns_received) = self + .publish_namespace_received_queue + .push(publish_ns_received) + { + publish_ns_received.close(ServeError::Cancel)?; return Ok(()); } entry.insert(recv); @@ -193,14 +205,55 @@ impl Subscriber { Ok(()) } - /// Handle the reception of a PublishNamespaceDone message from the publisher. - fn recv_publish_namespace_done( + fn recv_publish_ns_done( &mut self, msg: &message::PublishNamespaceDone, ) -> Result<(), SessionError> { - if let Some(announce) = self.announced.lock().unwrap().remove(&msg.track_namespace) { - announce.recv_unannounce()?; + if let Some(entry) = self + .publish_namespaces_received + .lock() + .unwrap() + .remove(&msg.track_namespace) + { + entry.recv_done()?; + } + + Ok(()) + } + + /// Handle NAMESPACE message (draft-16) - relay forwards this in response to SUBSCRIBE_NAMESPACE + fn recv_namespace(&mut self, msg: &message::Namespace) -> Result<(), SessionError> { + log::info!( + "received NAMESPACE for {:?} (request_id={})", + msg.track_namespace, + msg.id + ); + // TODO: Implement proper handling - notify the SUBSCRIBE_NAMESPACE handler + // For now, just log and accept + Ok(()) + } + + fn recv_publish(&mut self, msg: &message::Publish) -> Result<(), SessionError> { + let mut entries = self.publishes_received.lock().unwrap(); + + let entry = match entries.entry(msg.id) { + hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), + hash_map::Entry::Vacant(entry) => entry, + }; + + let (publish_received, recv) = PublishReceived::new(self.clone(), msg); + + self.publish_alias_map + .lock() + .unwrap() + .insert(msg.track_alias, msg.id); + self.publish_alias_notify.notify_waiters(); + + if let Err(publish_received) = self.publish_received_queue.push(publish_received) { + publish_received.close(ServeError::Cancel)?; + return Ok(()); } + entry.insert(recv); Ok(()) } @@ -240,19 +293,45 @@ impl Subscriber { } } - /// Handle the reception of a SubscribeError message from the publisher. - fn recv_subscribe_error(&mut self, msg: &message::SubscribeError) -> Result<(), SessionError> { + fn recv_request_ok(&mut self, msg: &message::RequestOk) -> Result<(), SessionError> { + if let Some(subscribe_ns) = self.subscribe_namespaces.lock().unwrap().get_mut(&msg.id) { + subscribe_ns.recv_ok()?; + return Ok(()); + } + + log::warn!( + "[SUBSCRIBER] recv_request_ok: request id {} not found", + msg.id + ); + Ok(()) + } + + fn recv_request_error(&mut self, msg: &message::RequestError) -> Result<(), SessionError> { if let Some(subscribe) = self.remove_subscribe(msg.id) { subscribe.error(ServeError::Closed(msg.error_code))?; + return Ok(()); + } + + if let Some(subscribe_ns) = self.subscribe_namespaces.lock().unwrap().remove(&msg.id) { + subscribe_ns.recv_error(ServeError::Closed(msg.error_code))?; + return Ok(()); } + log::warn!( + "[SUBSCRIBER] recv_request_error: request id {} not found", + msg.id + ); Ok(()) } - /// Handle the reception of a PublishDone message from the publisher. fn recv_publish_done(&mut self, msg: &message::PublishDone) -> Result<(), SessionError> { if let Some(subscribe) = self.remove_subscribe(msg.id) { subscribe.error(ServeError::Closed(msg.status_code))?; + return Ok(()); + } + + if let Some(mut publish_recv) = self.remove_publish_received(msg.id) { + publish_recv.recv_done()?; } Ok(()) @@ -266,23 +345,32 @@ impl Subscriber { Ok(()) } - /// Remove an announced namespace from our map of active announces. fn drop_publish_namespace(&mut self, namespace: &TrackNamespace) { - self.announced.lock().unwrap().remove(namespace); + self.publish_namespaces_received + .lock() + .unwrap() + .remove(namespace); + } + + fn remove_publish_received(&mut self, id: u64) -> Option { + if let Some(publish_recv) = self.publishes_received.lock().unwrap().remove(&id) { + if let Some(track_alias) = publish_recv.track_alias() { + self.publish_alias_map.lock().unwrap().remove(&track_alias); + } + Some(publish_recv) + } else { + None + } } - /// Get a subscribe id by track alias, waiting up to the specified timeout if not present. - /// If timeout_ms is None, only check if already present and return None if not. async fn get_subscribe_id_by_alias( &self, track_alias: u64, timeout_ms: Option, ) -> Option { - // If no timeout specified, don't wait let timeout_ms = match timeout_ms { Some(ms) => ms, None => { - // Just check once return self .subscribe_alias_map .lock() @@ -292,14 +380,11 @@ impl Subscriber { } }; - // Wait for it to appear, checking after each notification let timeout_duration = Duration::from_millis(timeout_ms); tokio::time::timeout(timeout_duration, async { loop { - // Register for notification before checking map let notified = self.subscribe_alias_notify.notified(); - // Check Map for alias if let Some(id) = self .subscribe_alias_map .lock() @@ -310,7 +395,45 @@ impl Subscriber { return id; } - // Alias not present yet, wait for notification + notified.await; + } + }) + .await + .ok() + } + + async fn get_publish_id_by_alias( + &self, + track_alias: u64, + timeout_ms: Option, + ) -> Option { + let timeout_ms = match timeout_ms { + Some(ms) => ms, + None => { + return self + .publish_alias_map + .lock() + .unwrap() + .get(&track_alias) + .cloned(); + } + }; + + let timeout_duration = Duration::from_millis(timeout_ms); + tokio::time::timeout(timeout_duration, async { + loop { + let notified = self.publish_alias_notify.notified(); + + if let Some(id) = self + .publish_alias_map + .lock() + .unwrap() + .get(&track_alias) + .cloned() + { + return id; + } + notified.await; } }) @@ -376,7 +499,6 @@ impl Subscriber { res } - /// Continue handling the reception of a new stream from the QUIC session. async fn recv_stream_inner( &mut self, reader: Reader, @@ -389,19 +511,28 @@ impl Subscriber { track_alias ); - // This is super silly, but I couldn't figure out a way to avoid the mutex guard across awaits. enum Writer { - //Fetch(serve::FetchWriter), Subgroup(serve::SubgroupWriter), } + // First check both maps WITHOUT waiting - this is the fast path for subsequent groups + // where the alias mapping is already established + let (subscribe_id_immediate, publish_id_immediate) = { + let subscribe_id = self.get_subscribe_id_by_alias(track_alias, None).await; + let publish_id = self.get_publish_id_by_alias(track_alias, None).await; + (subscribe_id, publish_id) + }; + + log::debug!( + "[SUBSCRIBER] recv_stream_inner: track_alias={}, subscribe_id_immediate={:?}, publish_id_immediate={:?}", + track_alias, subscribe_id_immediate, publish_id_immediate + ); + + // Determine which path to use, waiting only if neither map has the alias yet let writer = { - // Look up the subscribe id for this track alias - if let Some(subscribe_id) = self - .get_subscribe_id_by_alias(track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)) - .await - { - // Look up the subscribe by id + if let Some(subscribe_id) = subscribe_id_immediate { + // Found in subscribe map immediately + log::debug!("[SUBSCRIBER] recv_stream_inner: using SUBSCRIBE path (immediate)"); let mut subscribes = self.subscribes.lock().unwrap(); let subscribe = subscribes.get_mut(&subscribe_id).ok_or_else(|| { ServeError::not_found_ctx(format!( @@ -410,10 +541,37 @@ impl Subscriber { )) })?; - // Create the appropriate writer based on the stream header type if stream_header.header_type.is_subgroup() { - log::trace!("[SUBSCRIBER] recv_stream_inner: creating subgroup writer"); - Writer::Subgroup(subscribe.subgroup(stream_header.subgroup_header.unwrap())?) + log::trace!( + "[SUBSCRIBER] recv_stream_inner: creating subgroup writer from subscribe" + ); + Writer::Subgroup( + subscribe.subgroup(stream_header.subgroup_header.clone().unwrap())?, + ) + } else { + return Err(SessionError::Serve(ServeError::internal_ctx(format!( + "unsupported stream header type={}", + stream_header.header_type + )))); + } + } else if let Some(publish_id) = publish_id_immediate { + // Found in publish map immediately + log::debug!("[SUBSCRIBER] recv_stream_inner: using PUBLISH path (immediate)"); + let mut publishes = self.publishes_received.lock().unwrap(); + let publish_recv = publishes.get_mut(&publish_id).ok_or_else(|| { + ServeError::not_found_ctx(format!( + "publish_id={} not found for track_alias={}", + publish_id, track_alias + )) + })?; + + if stream_header.header_type.is_subgroup() { + log::trace!( + "[SUBSCRIBER] recv_stream_inner: creating subgroup writer from publish" + ); + Writer::Subgroup( + publish_recv.subgroup(stream_header.subgroup_header.clone().unwrap())?, + ) } else { return Err(SessionError::Serve(ServeError::internal_ctx(format!( "unsupported stream header type={}", @@ -421,16 +579,69 @@ impl Subscriber { )))); } } else { - return Err(SessionError::Serve(ServeError::not_found_ctx(format!( - "subscription track_alias={} not found", + // Not found in either map - wait for either to become available + // This only happens for the first stream before control messages establish the mapping + log::debug!( + "[SUBSCRIBER] recv_stream_inner: track_alias={} NOT FOUND in either map, WAITING for alias mapping", track_alias - )))); + ); + + // Race both lookups with timeout + let subscribe_fut = self.get_subscribe_id_by_alias(track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)); + let publish_fut = self.get_publish_id_by_alias(track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)); + + tokio::select! { + Some(subscribe_id) = subscribe_fut => { + let mut subscribes = self.subscribes.lock().unwrap(); + let subscribe = subscribes.get_mut(&subscribe_id).ok_or_else(|| { + ServeError::not_found_ctx(format!( + "subscribe_id={} not found for track_alias={}", + subscribe_id, track_alias + )) + })?; + + if stream_header.header_type.is_subgroup() { + Writer::Subgroup( + subscribe.subgroup(stream_header.subgroup_header.clone().unwrap())?, + ) + } else { + return Err(SessionError::Serve(ServeError::internal_ctx(format!( + "unsupported stream header type={}", + stream_header.header_type + )))); + } + } + Some(publish_id) = publish_fut => { + let mut publishes = self.publishes_received.lock().unwrap(); + let publish_recv = publishes.get_mut(&publish_id).ok_or_else(|| { + ServeError::not_found_ctx(format!( + "publish_id={} not found for track_alias={}", + publish_id, track_alias + )) + })?; + + if stream_header.header_type.is_subgroup() { + Writer::Subgroup( + publish_recv.subgroup(stream_header.subgroup_header.clone().unwrap())?, + ) + } else { + return Err(SessionError::Serve(ServeError::internal_ctx(format!( + "unsupported stream header type={}", + stream_header.header_type + )))); + } + } + else => { + return Err(SessionError::Serve(ServeError::not_found_ctx(format!( + "subscription track_alias={} not found", + track_alias + )))); + } + } } }; - // Handle the stream based on the writer type match writer { - //Writer::Fetch(fetch) => Self::recv_fetch(fetch, reader).await?, Writer::Subgroup(subgroup_writer) => { log::trace!("[SUBSCRIBER] recv_stream_inner: receiving subgroup data"); Self::recv_subgroup(stream_header.header_type, subgroup_writer, reader, mlog) @@ -513,6 +724,20 @@ impl Subscriber { } } + // Check for Prior Object ID Gap (type 0x3E = 62) + if object.extension_headers.has(0x3E) { + log::info!( + "[SUBSCRIBER] recv_subgroup: object #{} contains PRIOR OBJECT ID GAP (type 0x3E)", + object_count + 1 + ); + if let Some(gap_ext) = object.extension_headers.get(0x3E) { + log::debug!( + "[SUBSCRIBER] recv_subgroup: prior object id gap details: {:?}", + gap_ext + ); + } + } + let obj_copy = object.clone(); ( object.payload_length, @@ -581,10 +806,12 @@ impl Subscriber { } } - // Pass extension headers through to the serve layer - // TODO SLG - object_id_delta and object status are still being ignored - - let mut object_writer = subgroup_writer.create(remaining_bytes, extension_headers)?; + // Pass extension headers and status through to the serve layer + let mut object_writer = subgroup_writer.create_with_status( + remaining_bytes, + extension_headers, + status.unwrap_or(crate::data::ObjectStatus::NormalObject), + )?; log::trace!( "[SUBSCRIBER] recv_subgroup: reading payload for object #{} ({} bytes)", object_count + 1, @@ -624,11 +851,27 @@ impl Subscriber { object_count += 1; } + // If the stream header type signals end-of-group, write an EndOfGroup marker + // This forwards the "stream end = group end" semantic to downstream subscribers + if stream_header_type.signals_end_of_group() { + log::debug!( + "[SUBSCRIBER] recv_subgroup: writing EndOfGroup marker (header_type={:?} signals EOG)", + stream_header_type + ); + if let Err(e) = subgroup_writer.end_of_group() { + log::warn!( + "[SUBSCRIBER] recv_subgroup: failed to write EndOfGroup marker: {}", + e + ); + } + } + log::info!( - "[SUBSCRIBER] recv_subgroup: completed subgroup (group_id={}, subgroup_id={}, {} objects received)", + "[SUBSCRIBER] recv_subgroup: completed subgroup (group_id={}, subgroup_id={}, {} objects received, eog={})", subgroup_writer.info.group_id, subgroup_writer.info.subgroup_id, - object_count + object_count, + stream_header_type.signals_end_of_group() ); Ok(()) @@ -682,17 +925,33 @@ impl Subscriber { ); } } + + // Check for Prior Object ID Gap (type 0x3E = 62) + if ext_headers.has(0x3E) { + log::info!( + "[SUBSCRIBER] recv_datagram: datagram contains PRIOR OBJECT ID GAP (type 0x3E)" + ); + if let Some(gap_ext) = ext_headers.get(0x3E) { + log::debug!( + "[SUBSCRIBER] recv_datagram: prior object id gap details: {:?}", + gap_ext + ); + } + } } - // Look up the subscribe id for this track alias - if let Some(subscribe_id) = self - .get_subscribe_id_by_alias(datagram.track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)) - .await - { - // Look up the subscribe by id + // Fast path: check both maps immediately WITHOUT waiting + // This allows datagrams to flow at full rate once alias mapping is established + let (subscribe_id_immediate, publish_id_immediate) = { + let subscribe_id = self.get_subscribe_id_by_alias(datagram.track_alias, None).await; + let publish_id = self.get_publish_id_by_alias(datagram.track_alias, None).await; + (subscribe_id, publish_id) + }; + + if let Some(subscribe_id) = subscribe_id_immediate { if let Some(subscribe) = self.subscribes.lock().unwrap().get_mut(&subscribe_id) { log::trace!( - "[SUBSCRIBER] recv_datagram: track_alias={}, group_id={}, object_id={}, publisher_priority={}, status={}, payload_length={}", + "[SUBSCRIBER] recv_datagram: track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", datagram.track_alias, datagram.group_id, datagram.object_id.unwrap_or(0), @@ -701,15 +960,63 @@ impl Subscriber { datagram.payload.as_ref().map_or(0, |p| p.len())); subscribe.datagram(datagram)?; } + } else if let Some(publish_id) = publish_id_immediate { + if let Some(publish_recv) = self.publishes_received.lock().unwrap().get_mut(&publish_id) + { + log::trace!( + "[SUBSCRIBER] recv_datagram from publish: track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", + datagram.track_alias, + datagram.group_id, + datagram.object_id.unwrap_or(0), + datagram.publisher_priority, + datagram.status.as_ref().map_or("None".to_string(), |s| format!("{:?}", s)), + datagram.payload.as_ref().map_or(0, |p| p.len())); + publish_recv.datagram(datagram)?; + } } else { - log::warn!( - "[SUBSCRIBER] recv_datagram: discarded due to unknown track_alias: track_alias={}, group_id={}, object_id={}, publisher_priority={}, status={}, payload_length={}", - datagram.track_alias, - datagram.group_id, - datagram.object_id.unwrap_or(0), - datagram.publisher_priority, - datagram.status.as_ref().map_or("None".to_string(), |s| format!("{:?}", s)), - datagram.payload.as_ref().map_or(0, |p| p.len())); + // Slow path: alias not found immediately, wait with timeout (only for first datagram) + let subscribe_fut = self.get_subscribe_id_by_alias(datagram.track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)); + let publish_fut = self.get_publish_id_by_alias(datagram.track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)); + + tokio::select! { + Some(subscribe_id) = subscribe_fut => { + if let Some(subscribe) = self.subscribes.lock().unwrap().get_mut(&subscribe_id) { + log::trace!( + "[SUBSCRIBER] recv_datagram (waited): track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", + datagram.track_alias, + datagram.group_id, + datagram.object_id.unwrap_or(0), + datagram.publisher_priority, + datagram.status.as_ref().map_or("None".to_string(), |s| format!("{:?}", s)), + datagram.payload.as_ref().map_or(0, |p| p.len())); + subscribe.datagram(datagram)?; + } + } + Some(publish_id) = publish_fut => { + if let Some(publish_recv) = self.publishes_received.lock().unwrap().get_mut(&publish_id) + { + log::trace!( + "[SUBSCRIBER] recv_datagram from publish (waited): track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", + datagram.track_alias, + datagram.group_id, + datagram.object_id.unwrap_or(0), + datagram.publisher_priority, + datagram.status.as_ref().map_or("None".to_string(), |s| format!("{:?}", s)), + datagram.payload.as_ref().map_or(0, |p| p.len())); + publish_recv.datagram(datagram)?; + } + } + else => { + log::warn!( + "[SUBSCRIBER] recv_datagram: discarded due to unknown track_alias: track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", + datagram.track_alias, + datagram.group_id, + datagram.object_id.unwrap_or(0), + datagram.publisher_priority, + datagram.status.as_ref().map_or("None".to_string(), |s| format!("{:?}", s)), + datagram.payload.as_ref().map_or(0, |p| p.len())); + } + } } Ok(()) diff --git a/moq-transport/src/session/track_status_requested.rs b/moq-transport/src/session/track_status_requested.rs index 4c9cf744..587610be 100644 --- a/moq-transport/src/session/track_status_requested.rs +++ b/moq-transport/src/session/track_status_requested.rs @@ -21,9 +21,10 @@ impl TrackStatusRequested { error_code: u64, error_message: &str, ) -> Result<(), SessionError> { - let status_error = message::TrackStatusError { + let status_error = message::RequestError { id: self.request_msg.id, error_code, + retry_interval: 0, reason_phrase: ReasonPhrase(error_message.to_string()), }; self.publisher.send_message(status_error); diff --git a/moq-transport/src/setup/auth_token.rs b/moq-transport/src/setup/auth_token.rs new file mode 100644 index 00000000..a1b22be5 --- /dev/null +++ b/moq-transport/src/setup/auth_token.rs @@ -0,0 +1,298 @@ +//! Authorization Token support for MOQT. +//! +//! This module provides support for authorization tokens as defined in the MOQT specification. +//! Tokens can be sent inline or referenced by alias to avoid retransmission of large tokens. + +use std::collections::HashMap; + +/// Authorization Token Types +/// +/// Defines how an authorization token is transmitted in messages. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[repr(u8)] +pub enum AuthTokenType { + /// No authorization token present + None = 0x0, + /// Authorization token sent inline + Inline = 0x1, + /// Authorization token referenced by alias + Alias = 0x2, + /// Authorization token cached with new alias + Store = 0x3, + /// Use previously stored token (DELETE is not allowed in CLIENT_SETUP) + UseAlias = 0x4, +} + +impl TryFrom for AuthTokenType { + type Error = (); + + fn try_from(value: u8) -> Result { + match value { + 0x0 => Ok(Self::None), + 0x1 => Ok(Self::Inline), + 0x2 => Ok(Self::Alias), + 0x3 => Ok(Self::Store), + 0x4 => Ok(Self::UseAlias), + _ => Err(()), + } + } +} + +/// An authorization token value +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct AuthToken { + /// The raw token bytes + pub token: Vec, + /// Optional alias for caching + pub alias: Option, +} + +impl AuthToken { + /// Create a new authorization token + pub fn new(token: Vec) -> Self { + Self { token, alias: None } + } + + /// Create a new authorization token with an alias for caching + pub fn with_alias(token: Vec, alias: u64) -> Self { + Self { + token, + alias: Some(alias), + } + } + + /// Check if the token is empty + pub fn is_empty(&self) -> bool { + self.token.is_empty() + } +} + +/// Authorization Token Cache +/// +/// Stores authorization tokens by their alias for efficient re-use across multiple messages. +/// The cache enforces a maximum size limit as negotiated during setup. +#[derive(Debug)] +pub struct AuthTokenCache { + /// Maximum number of tokens that can be cached + max_size: usize, + /// Cached tokens by alias + tokens: HashMap>, + /// Next available alias (for server-assigned aliases) + next_alias: u64, +} + +impl Default for AuthTokenCache { + fn default() -> Self { + Self::new(0) + } +} + +impl AuthTokenCache { + /// Create a new auth token cache with the specified maximum size + pub fn new(max_size: usize) -> Self { + Self { + max_size, + tokens: HashMap::new(), + next_alias: 0, + } + } + + /// Get the maximum cache size + pub fn max_size(&self) -> usize { + self.max_size + } + + /// Set the maximum cache size (typically from setup negotiation) + pub fn set_max_size(&mut self, max_size: usize) { + self.max_size = max_size; + } + + /// Get the current number of cached tokens + pub fn len(&self) -> usize { + self.tokens.len() + } + + /// Check if the cache is empty + pub fn is_empty(&self) -> bool { + self.tokens.is_empty() + } + + /// Check if the cache is at capacity + pub fn is_full(&self) -> bool { + self.tokens.len() >= self.max_size + } + + /// Store a token with the given alias + /// + /// Returns an error if: + /// - The cache is at capacity + /// - The alias is already in use + pub fn store(&mut self, alias: u64, token: Vec) -> Result<(), AuthTokenCacheError> { + if self.max_size == 0 { + return Err(AuthTokenCacheError::CacheDisabled); + } + if self.tokens.len() >= self.max_size { + return Err(AuthTokenCacheError::CacheOverflow); + } + if self.tokens.contains_key(&alias) { + return Err(AuthTokenCacheError::DuplicateAlias(alias)); + } + self.tokens.insert(alias, token); + Ok(()) + } + + /// Store a token with an auto-generated alias + /// + /// Returns the assigned alias, or an error if the cache is full + pub fn store_with_auto_alias(&mut self, token: Vec) -> Result { + if self.max_size == 0 { + return Err(AuthTokenCacheError::CacheDisabled); + } + if self.tokens.len() >= self.max_size { + return Err(AuthTokenCacheError::CacheOverflow); + } + + // Find next available alias + while self.tokens.contains_key(&self.next_alias) { + self.next_alias = self.next_alias.wrapping_add(1); + } + + let alias = self.next_alias; + self.tokens.insert(alias, token); + self.next_alias = self.next_alias.wrapping_add(1); + + Ok(alias) + } + + /// Get a token by its alias + pub fn get(&self, alias: u64) -> Option<&Vec> { + self.tokens.get(&alias) + } + + /// Remove a token by its alias + pub fn remove(&mut self, alias: u64) -> Option> { + self.tokens.remove(&alias) + } + + /// Clear all cached tokens + pub fn clear(&mut self) { + self.tokens.clear(); + } +} + +/// Errors that can occur when working with the auth token cache +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum AuthTokenCacheError { + /// The cache is disabled (max_size is 0) + CacheDisabled, + /// The cache is full and cannot accept more tokens + CacheOverflow, + /// The alias is already in use + DuplicateAlias(u64), + /// The alias was not found in the cache + UnknownAlias(u64), + /// The token is malformed + MalformedToken, + /// The token has expired + ExpiredToken, +} + +impl std::fmt::Display for AuthTokenCacheError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::CacheDisabled => write!(f, "authorization token cache is disabled"), + Self::CacheOverflow => write!(f, "authorization token cache is full"), + Self::DuplicateAlias(alias) => { + write!(f, "duplicate authorization token alias: {}", alias) + } + Self::UnknownAlias(alias) => { + write!(f, "unknown authorization token alias: {}", alias) + } + Self::MalformedToken => write!(f, "malformed authorization token"), + Self::ExpiredToken => write!(f, "expired authorization token"), + } + } +} + +impl std::error::Error for AuthTokenCacheError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_auth_token_type_conversion() { + assert_eq!(AuthTokenType::try_from(0u8), Ok(AuthTokenType::None)); + assert_eq!(AuthTokenType::try_from(1u8), Ok(AuthTokenType::Inline)); + assert_eq!(AuthTokenType::try_from(2u8), Ok(AuthTokenType::Alias)); + assert_eq!(AuthTokenType::try_from(3u8), Ok(AuthTokenType::Store)); + assert_eq!(AuthTokenType::try_from(4u8), Ok(AuthTokenType::UseAlias)); + assert!(AuthTokenType::try_from(5u8).is_err()); + } + + #[test] + fn test_auth_token() { + let token = AuthToken::new(vec![1, 2, 3, 4]); + assert!(!token.is_empty()); + assert!(token.alias.is_none()); + + let token_with_alias = AuthToken::with_alias(vec![5, 6, 7, 8], 42); + assert_eq!(token_with_alias.alias, Some(42)); + + let empty_token = AuthToken::default(); + assert!(empty_token.is_empty()); + } + + #[test] + fn test_auth_token_cache() { + let mut cache = AuthTokenCache::new(3); + assert_eq!(cache.max_size(), 3); + assert!(cache.is_empty()); + + // Store tokens + cache.store(1, vec![1, 2, 3]).unwrap(); + cache.store(2, vec![4, 5, 6]).unwrap(); + assert_eq!(cache.len(), 2); + assert!(!cache.is_full()); + + // Get token + assert_eq!(cache.get(1), Some(&vec![1, 2, 3])); + assert_eq!(cache.get(2), Some(&vec![4, 5, 6])); + assert_eq!(cache.get(3), None); + + // Store with auto-alias + let alias = cache.store_with_auto_alias(vec![7, 8, 9]).unwrap(); + assert!(cache.is_full()); + + // Cache overflow + assert_eq!( + cache.store(99, vec![10, 11]), + Err(AuthTokenCacheError::CacheOverflow) + ); + + // Duplicate alias + cache.remove(alias); + assert_eq!( + cache.store(1, vec![10, 11]), + Err(AuthTokenCacheError::DuplicateAlias(1)) + ); + + // Remove and clear + assert!(cache.remove(1).is_some()); + cache.clear(); + assert!(cache.is_empty()); + } + + #[test] + fn test_auth_token_cache_disabled() { + let mut cache = AuthTokenCache::new(0); + assert_eq!( + cache.store(1, vec![1, 2, 3]), + Err(AuthTokenCacheError::CacheDisabled) + ); + assert_eq!( + cache.store_with_auto_alias(vec![1, 2, 3]), + Err(AuthTokenCacheError::CacheDisabled) + ); + } +} diff --git a/moq-transport/src/setup/client.rs b/moq-transport/src/setup/client.rs index edefb32e..2d5b7f7e 100644 --- a/moq-transport/src/setup/client.rs +++ b/moq-transport/src/setup/client.rs @@ -1,14 +1,9 @@ -use super::Versions; use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; /// Sent by the client to setup the session. -/// This CLIENT_SETUP message is used by moq-transport draft versions 11 and later. -/// Id = 0x20 vs 0x40 for versions <= 10. +/// Draft-16: version negotiation uses ALPN; no Versions field in CLIENT_SETUP. #[derive(Debug)] pub struct Client { - /// The list of supported versions in preferred order. - pub versions: Versions, - /// Setup Parameters, ie: PATH, MAX_REQUEST_ID, /// MAX_AUTH_TOKEN_CACHE_SIZE, AUTHORIZATION_TOKEN, etc. pub params: KeyValuePairs, @@ -26,10 +21,9 @@ impl Decode for Client { let _len = u16::decode(r)?; // TODO: Check the length of the message. - let versions = Versions::decode(r)?; let params = KeyValuePairs::decode(r)?; - Ok(Self { versions, params }) + Ok(Self { params }) } } @@ -45,7 +39,6 @@ impl Encode for Client { // write the length later, to avoid the copy of the message bytes? let mut buf = Vec::new(); - self.versions.encode(&mut buf).unwrap(); self.params.encode(&mut buf).unwrap(); // Make sure buf.len() <= u16::MAX @@ -66,7 +59,7 @@ impl Encode for Client { #[cfg(test)] mod tests { use super::*; - use crate::setup::{ParameterType, Version}; + use crate::setup::ParameterType; use bytes::BytesMut; #[test] @@ -76,26 +69,22 @@ mod tests { let mut params = KeyValuePairs::default(); params.set_bytesvalue(ParameterType::Path.into(), "testpath".as_bytes().to_vec()); - let client = Client { - versions: [Version::DRAFT_13].into(), - params, - }; + let client = Client { params }; client.encode(&mut buf).unwrap(); + // Draft-16: no Versions field, just Type + Length + Parameters #[rustfmt::skip] assert_eq!( buf.to_vec(), vec![ - 0x20, // Type - 0x00, 0x14, // Length - 0x01, // 1 Version - 0xC0, 0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x0D, // Version DRAFT_13 (0xff00000D) - 0x01, // 1 Param - 0x01, 0x08, 0x74, 0x65, 0x73, 0x74, 0x70, 0x61, 0x74, 0x68, // Key=1 (Path), Value="testpath" + 0x20, // Type (CLIENT_SETUP) + 0x00, 0x0b, // Length = 11 bytes + 0x01, // 1 Parameter (count) + // Delta=1 (Path), Length=8, "testpath" + 0x01, 0x08, 0x74, 0x65, 0x73, 0x74, 0x70, 0x61, 0x74, 0x68, ] ); let decoded = Client::decode(&mut buf).unwrap(); - assert_eq!(decoded.versions, client.versions); assert_eq!(decoded.params, client.params); } } diff --git a/moq-transport/src/setup/mod.rs b/moq-transport/src/setup/mod.rs index 44e3664b..1098ed7a 100644 --- a/moq-transport/src/setup/mod.rs +++ b/moq-transport/src/setup/mod.rs @@ -4,14 +4,16 @@ //! The client sends the [Client] message and the server responds with the [Server] message. //! Both sides negotate the [Version] and [Role]. +mod auth_token; mod client; mod param_types; mod server; mod version; +pub use auth_token::*; pub use client::*; pub use param_types::*; pub use server::*; pub use version::*; -pub const ALPN: &[u8] = b"moq-00"; +pub const ALPN: &[u8] = b"moqt-16"; diff --git a/moq-transport/src/setup/param_types.rs b/moq-transport/src/setup/param_types.rs index 2f4e9862..65f731e5 100644 --- a/moq-transport/src/setup/param_types.rs +++ b/moq-transport/src/setup/param_types.rs @@ -7,7 +7,11 @@ pub enum ParameterType { AuthorizationToken = 0x3, MaxAuthTokenCacheSize = 0x4, Authority = 0x5, + /// Maximum number of Range pairs allowed per subscription/fetch (PR #1518) + MaxFilterRanges = 0x6, MOQTImplementation = 0x7, + /// Maximum value for MaxTracksSelected parameter in TRACK_FILTER (PR #1518) + MaxTracksSelected = 0x8, } impl From for u64 { diff --git a/moq-transport/src/setup/server.rs b/moq-transport/src/setup/server.rs index 3fae91f8..7880228b 100644 --- a/moq-transport/src/setup/server.rs +++ b/moq-transport/src/setup/server.rs @@ -1,14 +1,9 @@ -use super::Version; use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; /// Sent by the server in response to a client setup. -/// This SERVER_SETUP message is used by moq-transport draft versions 11 and later. -/// Id = 0x21 vs 0x41 for versions <= 10. +/// Draft-16: version negotiation uses ALPN; no Versions field in SERVER_SETUP. #[derive(Debug)] pub struct Server { - /// The list of supported versions in preferred order. - pub version: Version, - /// Setup Parameters, ie: MAX_REQUEST_ID, MAX_AUTH_TOKEN_CACHE_SIZE, /// AUTHORIZATION_TOKEN, etc. pub params: KeyValuePairs, @@ -26,10 +21,9 @@ impl Decode for Server { let _len = u16::decode(r)?; // TODO: Check the length of the message. - let version = Version::decode(r)?; let params = KeyValuePairs::decode(r)?; - Ok(Self { version, params }) + Ok(Self { params }) } } @@ -44,7 +38,6 @@ impl Encode for Server { // write the length later, to avoid the copy of the message bytes? let mut buf = Vec::new(); - self.version.encode(&mut buf).unwrap(); self.params.encode(&mut buf).unwrap(); // Make sure buf.len() <= u16::MAX @@ -75,27 +68,24 @@ mod tests { let mut params = KeyValuePairs::default(); params.set_intvalue(ParameterType::MaxRequestId.into(), 1000); - let server = Server { - version: Version::DRAFT_14, - params, - }; + let server = Server { params }; server.encode(&mut buf).unwrap(); + // Draft-16: no Versions field, just Type + Length + Parameters #[rustfmt::skip] assert_eq!( buf.to_vec(), vec![ - 0x21, // Type - 0x00, 0x0c, // Length - 0xC0, 0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x0E, // Version DRAFT_14 (0xff00000E) - 0x01, // 1 Param - 0x02, 0x43, 0xe8, // Key=2 (MaxRequestId), Value=1000 + 0x21, // Type (SERVER_SETUP) + 0x00, 0x04, // Length = 4 bytes + 0x01, // 1 Parameter (count) + // Delta=2 (MaxRequestId), Value=1000 + 0x02, 0x43, 0xe8, ] ); let decoded = Server::decode(&mut buf).unwrap(); - assert_eq!(decoded.version, server.version); assert_eq!(decoded.params, server.params); } } diff --git a/moq-transport/src/setup/version.rs b/moq-transport/src/setup/version.rs index fa896e7d..2fb41ae4 100644 --- a/moq-transport/src/setup/version.rs +++ b/moq-transport/src/setup/version.rs @@ -23,6 +23,12 @@ impl Version { /// https://www.ietf.org/archive/id/draft-ietf-moq-transport-14.html pub const DRAFT_14: Version = Version(0xff00000e); + + /// https://www.ietf.org/archive/id/draft-ietf-moq-transport-15.html + pub const DRAFT_15: Version = Version(0xff00000f); + + /// https://www.ietf.org/archive/id/draft-ietf-moq-transport-16.html + pub const DRAFT_16: Version = Version(0xff000010); } impl From for Version {