diff --git a/Cargo.lock b/Cargo.lock index 490e2e56..4eab2305 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1346,6 +1346,7 @@ dependencies = [ "serde_json", "thiserror 2.0.17", "tokio", + "tokio-util", "tower-http", "tracing", "tracing-subscriber", diff --git a/moq-relay-ietf/Cargo.toml b/moq-relay-ietf/Cargo.toml index 780c176c..1430a943 100644 --- a/moq-relay-ietf/Cargo.toml +++ b/moq-relay-ietf/Cargo.toml @@ -33,7 +33,7 @@ url = "2" # Async stuff tokio = { version = "1", features = ["full"] } -# tokio-util = "0.7" +tokio-util = "0.7" futures = "0.3" async-trait = "0.1" diff --git a/moq-relay-ietf/src/producer.rs b/moq-relay-ietf/src/producer.rs index 15890a3c..2f42c106 100644 --- a/moq-relay-ietf/src/producer.rs +++ b/moq-relay-ietf/src/producer.rs @@ -9,7 +9,7 @@ use moq_transport::{ use crate::{ metrics::{GaugeGuard, TimingGuard}, - Locals, RemotesConsumer, + Locals, RemoteManager, }; /// Producer of tracks to a remote Subscriber @@ -17,7 +17,7 @@ use crate::{ pub struct Producer { publisher: Publisher, locals: Locals, - remotes: Option, + remotes: RemoteManager, /// The resolved scope identity for this session, if any. /// Produced by `Coordinator::resolve_scope()` from the connection path. /// Passed to locals/remotes to isolate namespace lookups. @@ -28,7 +28,7 @@ impl Producer { pub fn new( publisher: Publisher, locals: Locals, - remotes: Option, + remotes: RemoteManager, scope: Option, ) -> Self { Self { @@ -122,40 +122,33 @@ impl Producer { } } - if let Some(remotes) = self.remotes { - // Check remote tracks second, and serve from remote if possible - match remotes.route(self.scope.as_deref(), &namespace).await { - Ok(remote) => { - if let Some(remote) = remote { - if let Some(track) = remote.subscribe(&namespace, &track_name)? { - let ns = namespace.to_utf8_path(); - tracing::info!(namespace = %ns, track = %track_name, source = "remote", "serving subscribe from remote: {:?}", track.info); - // Update label to indicate remote source, timing recorded on drop - timing_guard.set_label("source", "remote"); - // Track active tracks - decrements when serve completes - let _track_guard = GaugeGuard::new("moq_relay_active_tracks"); - return Ok(subscribed.serve(track.reader).await?); - } - } - } - Err(e) => { - // Route error = infrastructure failure (couldn't reach coordinator/upstream) - // This is different from "not found" - we don't know if the track exists + match self + .remotes + .subscribe(self.scope.as_deref(), &namespace, &track_name) + .await + { + Ok(track) => { + if let Some(track) = track { let ns = namespace.to_utf8_path(); - tracing::error!(namespace = %ns, track = %track_name, error = %e, "failed to route to remote: {}", e); - timing_guard.set_label("source", "route_error"); - metrics::counter!("moq_relay_subscribe_route_errors_total").increment(1); - - // Return an internal error rather than "not found" since we couldn't check - // TODO: Consider returning a more specific error to the subscriber - let err = ServeError::internal_ctx(format!( - "route error for namespace '{}': {}", - namespace, e - )); - subscribed.close(err.clone())?; - return Err(err.into()); + tracing::info!(namespace = %ns, track = %track_name, source = "remote", "serving subscribe from remote: {:?}", track.info); + timing_guard.set_label("source", "remote"); + let _track_guard = GaugeGuard::new("moq_relay_active_tracks"); + return Ok(subscribed.serve(track).await?); } } + Err(e) => { + let ns = namespace.to_utf8_path(); + tracing::error!(namespace = %ns, track = %track_name, error = %e, "failed to route to remote: {}", e); + timing_guard.set_label("source", "route_error"); + metrics::counter!("moq_relay_subscribe_route_errors_total").increment(1); + + let err = ServeError::internal_ctx(format!( + "route error for namespace '{}': {}", + namespace, e + )); + subscribed.close(err.clone())?; + return Err(err.into()); + } } // Track not found - we checked all sources and the track doesn't exist diff --git a/moq-relay-ietf/src/relay.rs b/moq-relay-ietf/src/relay.rs index 5beea1b8..ee088a84 100644 --- a/moq-relay-ietf/src/relay.rs +++ b/moq-relay-ietf/src/relay.rs @@ -9,10 +9,7 @@ use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use moq_native_ietf::quic::{self, Endpoint}; use url::Url; -use crate::{ - metrics::GaugeGuard, Consumer, Coordinator, Locals, Producer, Remotes, RemotesConsumer, - RemotesProducer, Session, -}; +use crate::{metrics::GaugeGuard, Consumer, Coordinator, Locals, Producer, RemoteManager, Session}; // A type alias for boxed future type ServerFuture = Pin< @@ -64,7 +61,7 @@ pub struct Relay { announce_url: Option, mlog_dir: Option, locals: Locals, - remotes: Option<(RemotesProducer, RemotesConsumer)>, + remotes: RemoteManager, coordinator: Arc, } @@ -109,18 +106,14 @@ impl Relay { .collect::>(); // Create remote manager - uses coordinator for namespace lookups - let remotes = Remotes { - coordinator: config.coordinator.clone(), - quic: remote_clients[0].clone(), - } - .produce(); + let remotes = RemoteManager::new(config.coordinator.clone(), remote_clients); Ok(Self { quic_endpoints: endpoints, announce_url: config.announce, mlog_dir: config.mlog_dir, locals, - remotes: Some(remotes), + remotes, coordinator: config.coordinator, }) } @@ -130,10 +123,7 @@ impl Relay { let mut tasks = FuturesUnordered::new(); // Split remotes producer/consumer and spawn producer task - let remotes = self.remotes.map(|(producer, consumer)| { - tasks.push(producer.run().boxed()); - consumer - }); + let remotes = self.remotes; // Start the forwarder, if any let forward_producer = if let Some(url) = &self.announce_url { diff --git a/moq-relay-ietf/src/remote.rs b/moq-relay-ietf/src/remote.rs index 2adc8582..d7120f74 100644 --- a/moq-relay-ietf/src/remote.rs +++ b/moq-relay-ietf/src/remote.rs @@ -2,269 +2,175 @@ // SPDX-License-Identifier: MIT OR Apache-2.0 use std::collections::HashMap; - -use std::collections::VecDeque; -use std::fmt; use std::net::SocketAddr; -use std::ops; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use std::sync::Weak; - -/// Cache key for upstream relay-to-relay connections. -/// -/// Keyed by both URL and destination address so that connections are -/// reused only when both match. This matters when a [`Coordinator`] -/// returns the same URL for different namespaces (e.g. a shared relay -/// hostname) but distinguishes destinations via [`NamespaceOrigin::addr`]. -/// Without the address in the key, all namespaces that share a URL -/// would be routed through a single cached connection to whichever -/// upstream host was contacted first. -type RemoteCacheKey = (Url, Option); -use futures::stream::FuturesUnordered; -use futures::FutureExt; -use futures::StreamExt; use moq_native_ietf::quic; use moq_transport::coding::TrackNamespace; -use moq_transport::serve::{Track, TrackReader, TrackWriter}; -use moq_transport::watch::State; +use moq_transport::serve::{Track, TrackReader}; +use tokio::sync::Mutex; +use tokio_util::sync::CancellationToken; use url::Url; -use crate::{metrics::GaugeGuard, Coordinator}; - -/// Information about remote origins. -pub struct Remotes { - /// The client we use to fetch/store origin information. - pub coordinator: Arc, - - // A QUIC endpoint we'll use to fetch from other origins. - pub quic: quic::Client, -} - -impl Remotes { - pub fn produce(self) -> (RemotesProducer, RemotesConsumer) { - let (send, recv) = State::default().split(); - let info = Arc::new(self); - - let producer = RemotesProducer::new(info.clone(), send); - let consumer = RemotesConsumer::new(info, recv); +use crate::{metrics::GaugeGuard, Coordinator, CoordinatorError}; - (producer, consumer) - } -} - -#[derive(Default)] -struct RemotesState { - lookup: HashMap, - requested: VecDeque, -} +/// Cache key for upstream relay-to-relay connections. +/// +/// Keyed by both URL and destination address so that connections are reused +/// only when both match. +type RemoteCacheKey = (Url, Option); -// Clone for convenience, but there should only be one instance of this +/// Manages connections to remote relays. +/// +/// When a subscription request comes in for a namespace that isn't local, +/// RemoteManager uses the coordinator to find which remote relay serves it, +/// establishes a connection if needed, and subscribes to the track. #[derive(Clone)] -pub struct RemotesProducer { - info: Arc, - state: State, +pub struct RemoteManager { + coordinator: Arc, + clients: Vec, + remotes: Arc>>, } -impl RemotesProducer { - fn new(info: Arc, state: State) -> Self { - Self { info, state } - } - - /// Block until the next remote requested by a consumer. - async fn next(&mut self) -> Option { - loop { - { - let state = self.state.lock(); - if !state.requested.is_empty() { - return state.into_mut()?.requested.pop_front(); - } - - state.modified()? - } - .await; +impl RemoteManager { + /// Create a new RemoteManager. + pub fn new(coordinator: Arc, clients: Vec) -> Self { + Self { + coordinator, + clients, + remotes: Arc::new(Mutex::new(HashMap::new())), } } - /// Run the remotes producer to serve remote requests. - pub async fn run(mut self) -> anyhow::Result<()> { - let mut tasks = FuturesUnordered::new(); - - loop { - tokio::select! { - Some(mut remote) = self.next() => { - let url = remote.url.clone(); - let cache_key = (url.clone(), remote.addr); - - // Spawn a task to serve the remote - tasks.push(async move { - let info = remote.info.clone(); - let remote_url = url.to_string(); - - tracing::warn!(remote_url = %remote_url, "serving remote: {:?}", info); + /// Subscribe to a track from a remote relay. + /// + /// `scope` is the resolved scope identity from `Coordinator::resolve_scope()`, + /// passed through to the coordinator's `lookup()` to scope the search. + /// + /// Returns None if the namespace isn't found in any remote relay. + pub async fn subscribe( + &self, + scope: Option<&str>, + namespace: &TrackNamespace, + track_name: &str, + ) -> anyhow::Result> { + let (origin, client) = match self.coordinator.lookup(scope, namespace).await { + Ok(result) => result, + Err(CoordinatorError::NamespaceNotFound) => return Ok(None), + Err(err) => return Err(err.into()), + }; - // Run the remote producer - if let Err(err) = remote.run().await { - tracing::warn!(remote_url = %remote_url, error = %err, "failed serving remote: {:?}, error: {}", info, err); - } + let url = origin.url(); + let cache_key = (url.clone(), origin.addr()); + + let remote = match self + .get_or_connect(cache_key.clone(), client.as_ref()) + .await + { + Ok(remote) => remote, + Err(err) => { + tracing::error!(remote_url = %url, error = %err, "failed to connect to remote relay: {}", err); + self.remove(&cache_key).await; + return Err(err); + } + }; - cache_key - }); + match remote + .subscribe(namespace.clone(), track_name.to_string()) + .await + { + Ok(reader) => Ok(reader), + Err(err) => { + if !remote.is_connected() { + tracing::warn!(remote_url = %url, "remote connection is dead, removing from cache"); + self.remove(&cache_key).await; } - // Handle finished remote producers - res = tasks.next(), if !tasks.is_empty() => { - let cache_key = res.unwrap(); - - if let Some(mut state) = self.state.lock_mut() { - state.lookup.remove(&cache_key); - } - }, - else => return Ok(()), + Err(err) } } } -} - -impl ops::Deref for RemotesProducer { - type Target = Remotes; - - fn deref(&self) -> &Self::Target { - &self.info - } -} -#[derive(Clone)] -pub struct RemotesConsumer { - pub info: Arc, - state: State, -} - -impl RemotesConsumer { - fn new(info: Arc, state: State) -> Self { - Self { info, state } - } - - /// Route to a remote origin based on the namespace. - /// - /// `scope` is the resolved scope identity (from `Coordinator::resolve_scope()`), - /// passed through to the coordinator's `lookup()` to scope the search. - pub async fn route( + /// Get an existing remote connection or create a new one. + async fn get_or_connect( &self, - scope: Option<&str>, - namespace: &TrackNamespace, - ) -> anyhow::Result> { - // Always fetch the origin instead of using the (potentially invalid) cache. - let (origin, client) = self.coordinator.lookup(scope, namespace).await?; - - let cache_key = (origin.url(), origin.addr()); + cache_key: RemoteCacheKey, + client: Option<&quic::Client>, + ) -> anyhow::Result { + let mut remotes = self.remotes.lock().await; + + if let Some(remote) = remotes.get(&cache_key) { + if remote.is_connected() { + return Ok(remote.clone()); + } - // Check if we already have a remote for this origin - let state = self.state.lock(); - if let Some(remote) = state.lookup.get(&cache_key).cloned() { - return Ok(Some(remote)); + tracing::info!(remote_url = %cache_key.0, "removing dead connection to remote relay"); + remotes.remove(&cache_key); } - // Create a new remote for this origin - let mut state = match state.into_mut() { - Some(state) => state, - None => return Ok(None), - }; - - let remote = Remote { - url: origin.url(), - remotes: self.info.clone(), - addr: origin.addr(), - client, + let client = match client { + Some(client) => client, + None => self.clients.first().ok_or_else(|| { + anyhow::anyhow!("no QUIC clients configured for remote connections") + })?, }; - // Produce the remote - let (writer, reader) = remote.produce(); - state.requested.push_back(writer); + tracing::info!(remote_url = %cache_key.0, "connecting to remote relay"); + let remote = Remote::connect(cache_key.0.clone(), cache_key.1, client).await?; + remotes.insert(cache_key, remote.clone()); - // Insert the remote into our Map, keyed by both URL and destination address - state.lookup.insert(cache_key, reader.clone()); - - Ok(Some(reader)) + Ok(remote) } -} -impl ops::Deref for RemotesConsumer { - type Target = Remotes; - - fn deref(&self) -> &Self::Target { - &self.info + /// Remove a remote connection (called when connection fails). + async fn remove(&self, cache_key: &RemoteCacheKey) { + let mut remotes = self.remotes.lock().await; + if let Some(remote) = remotes.remove(cache_key) { + remote.shutdown(); + } } -} - -pub struct Remote { - pub remotes: Arc, - pub url: Url, - pub addr: Option, - pub client: Option, -} -impl fmt::Debug for Remote { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Remote") - .field("url", &self.url.to_string()) - .finish() + /// Shutdown all remote connections. + pub async fn shutdown(&self) { + let mut remotes = self.remotes.lock().await; + for (cache_key, remote) in remotes.drain() { + tracing::info!(remote_url = %cache_key.0, "shutting down remote connection"); + remote.shutdown(); + } } } -impl ops::Deref for Remote { - type Target = Remotes; - - fn deref(&self) -> &Self::Target { - &self.remotes - } +/// A connection to a single remote relay with its own QUIC client. +#[derive(Clone)] +pub struct Remote { + url: Url, + subscriber: moq_transport::session::Subscriber, + /// Track subscriptions - maps (namespace, track_name) to track reader + tracks: Arc>>, + /// Flag indicating if the connection is still alive + connected: Arc, + /// Cancellation token for the session task + cancel: CancellationToken, } impl Remote { - /// Create a new broadcast. - pub fn produce(self) -> (RemoteProducer, RemoteConsumer) { - let (send, recv) = State::default().split(); - let info = Arc::new(self); - - let consumer = RemoteConsumer::new(info.clone(), recv); - let producer = RemoteProducer::new(info, send); - - (producer, consumer) - } -} - -#[derive(Default)] -struct RemoteState { - tracks: HashMap<(TrackNamespace, String), RemoteTrackWeak>, - requested: VecDeque, -} - -pub struct RemoteProducer { - pub info: Arc, - state: State, -} - -impl RemoteProducer { - fn new(info: Arc, state: State) -> Self { - Self { info, state } - } - - pub async fn run(&mut self) -> anyhow::Result<()> { - let client = if let Some(client) = &self.info.client { - client - } else { - &self.quic + /// Connect to a remote relay with a dedicated QUIC client. + async fn connect( + url: Url, + addr: Option, + client: &quic::Client, + ) -> anyhow::Result { + let (session, _quic_client_initial_cid, transport) = match client.connect(&url, addr).await + { + Ok(session) => session, + Err(err) => { + metrics::counter!("moq_relay_upstream_errors_total", "stage" => "connect") + .increment(1); + return Err(err); + } }; - // TODO reuse QUIC and MoQ sessions - let (session, _quic_client_initial_cid, transport) = - match client.connect(&self.url, self.addr).await { - Ok(session) => session, - Err(err) => { - metrics::counter!("moq_relay_upstream_errors_total", "stage" => "connect") - .increment(1); - return Err(err); - } - }; + let (session, subscriber) = match moq_transport::session::Subscriber::connect(session, transport).await { Ok(session) => session, @@ -275,190 +181,97 @@ impl RemoteProducer { } }; - // Track established upstream connections - decrements when this function returns. - // Placed after successful connect + session setup so the gauge only reflects - // connections that are actually serving, not in-flight attempts. - let _upstream_guard = GaugeGuard::new("moq_relay_upstream_connections"); + let connected = Arc::new(AtomicBool::new(true)); + let cancel = CancellationToken::new(); + let upstream_guard = GaugeGuard::new("moq_relay_upstream_connections"); - // Run the session - let mut session = session.run().boxed(); - let mut tasks = FuturesUnordered::new(); + let session_url = url.clone(); + let session_connected = connected.clone(); + let session_cancel = cancel.clone(); - let mut done = None; - - // Serve requested tracks - loop { + tokio::spawn(async move { + let _upstream_guard = upstream_guard; tokio::select! { - track = self.next(), if done.is_none() => { - let track = match track { - Ok(Some(track)) => track, - Ok(None) => { done = Some(Ok(())); continue }, - Err(err) => { done = Some(Err(err)); continue }, - }; - - let info = track.info.clone(); - let mut subscriber = subscriber.clone(); - - tasks.push(async move { - if let Err(err) = subscriber.subscribe(track).await { - let namespace = info.namespace.to_utf8_path(); - let track_name = &info.name; - tracing::warn!(namespace = %namespace, track = %track_name, error = %err, "failed serving track: {:?}, error: {}", info, err); - } - }); + result = session.run() => { + if let Err(err) = result { + tracing::warn!(remote_url = %session_url, error = %err, "remote session closed: {}", err); + } else { + tracing::info!(remote_url = %session_url, "remote session closed normally"); + } } - _ = tasks.next(), if !tasks.is_empty() => {}, - - // Keep running the session - res = &mut session, if !tasks.is_empty() || done.is_none() => return Ok(res?), - - else => return done.unwrap(), - } - } - } - - /// Block until the next track requested by a consumer. - async fn next(&self) -> anyhow::Result> { - loop { - let notify = { - let state = self.state.lock(); - - // Check if we have any requested tracks - if !state.requested.is_empty() { - return Ok(state - .into_mut() - .and_then(|mut state| state.requested.pop_front())); + _ = session_cancel.cancelled() => { + tracing::info!(remote_url = %session_url, "remote session cancelled"); } + } - match state.modified() { - Some(notified) => notified, - None => return Ok(None), - } - }; + session_connected.store(false, Ordering::Release); + }); - notify.await - } + Ok(Self { + url, + subscriber, + tracks: Arc::new(Mutex::new(HashMap::new())), + connected, + cancel, + }) } -} - -impl ops::Deref for RemoteProducer { - type Target = Remote; - fn deref(&self) -> &Self::Target { - &self.info + /// Check if the connection is still alive. + pub fn is_connected(&self) -> bool { + self.connected.load(Ordering::Acquire) } -} - -#[derive(Clone)] -pub struct RemoteConsumer { - pub info: Arc, - state: State, -} -impl RemoteConsumer { - fn new(info: Arc, state: State) -> Self { - Self { info, state } + /// Shutdown the remote connection. + pub fn shutdown(&self) { + self.cancel.cancel(); + self.connected.store(false, Ordering::Release); } - /// Request a track from the broadcast. - pub fn subscribe( + /// Subscribe to a track on this remote relay. + pub async fn subscribe( &self, - namespace: &TrackNamespace, - name: &str, - ) -> anyhow::Result> { - let key = (namespace.clone(), name.to_string()); - let state = self.state.lock(); - if let Some(track) = state.tracks.get(&key) { - if let Some(track) = track.upgrade() { - return Ok(Some(track)); - } + namespace: TrackNamespace, + track_name: String, + ) -> anyhow::Result> { + if !self.is_connected() { + anyhow::bail!("remote connection to {} is closed", self.url); } - let mut state = match state.into_mut() { - Some(state) => state, - None => return Ok(None), - }; + let key = (namespace.clone(), track_name.clone()); + let mut tracks = self.tracks.lock().await; - let (writer, reader) = Track::new(namespace.clone(), name.to_string()).produce(); - let reader = RemoteTrackReader::new(reader, self.state.clone()); + if let Some(reader) = tracks.get(&key) { + return Ok(Some(reader.clone())); + } - // Insert the track into our Map so we deduplicate future requests. - state.tracks.insert(key, reader.downgrade()); - state.requested.push_back(writer); + let (writer, reader) = Track::new(namespace, track_name).produce(); + tracks.insert(key.clone(), reader.clone()); + drop(tracks); - Ok(Some(reader)) - } -} + let mut subscriber = self.subscriber.clone(); + let tracks = self.tracks.clone(); + let url = self.url.clone(); -impl ops::Deref for RemoteConsumer { - type Target = Remote; - - fn deref(&self) -> &Self::Target { - &self.info - } -} + tokio::spawn(async move { + tracing::info!(remote_url = %url, namespace = %key.0, track = %key.1, "subscribing to remote track"); -#[derive(Clone)] -pub struct RemoteTrackReader { - pub reader: TrackReader, - drop: Arc, -} + if let Err(err) = subscriber.subscribe(writer).await { + tracing::warn!(remote_url = %url, namespace = %key.0, track = %key.1, error = %err, "failed subscribing to remote track: {}", err); + } -impl RemoteTrackReader { - fn new(reader: TrackReader, parent: State) -> Self { - let drop = Arc::new(RemoteTrackDrop { - parent, - key: (reader.namespace.clone(), reader.name.clone()), + tracks.lock().await.remove(&key); + tracing::debug!(remote_url = %url, namespace = %key.0, track = %key.1, "remote track subscription ended"); }); - Self { reader, drop } - } - - fn downgrade(&self) -> RemoteTrackWeak { - RemoteTrackWeak { - reader: self.reader.clone(), - drop: Arc::downgrade(&self.drop), - } - } -} - -impl ops::Deref for RemoteTrackReader { - type Target = TrackReader; - - fn deref(&self) -> &Self::Target { - &self.reader - } -} - -impl ops::DerefMut for RemoteTrackReader { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.reader - } -} - -struct RemoteTrackWeak { - reader: TrackReader, - drop: Weak, -} - -impl RemoteTrackWeak { - fn upgrade(&self) -> Option { - Some(RemoteTrackReader { - reader: self.reader.clone(), - drop: self.drop.upgrade()?, - }) + Ok(Some(reader)) } } -struct RemoteTrackDrop { - parent: State, - key: (TrackNamespace, String), -} - -impl Drop for RemoteTrackDrop { - fn drop(&mut self) { - if let Some(mut parent) = self.parent.lock_mut() { - parent.tracks.remove(&self.key); - } +impl std::fmt::Debug for Remote { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Remote") + .field("url", &self.url.to_string()) + .field("connected", &self.is_connected()) + .finish() } }