diff --git a/livekit-android-test/src/main/java/io/livekit/android/test/mock/MockPeerConnection.kt b/livekit-android-test/src/main/java/io/livekit/android/test/mock/MockPeerConnection.kt index e9d825af..e2992424 100644 --- a/livekit-android-test/src/main/java/io/livekit/android/test/mock/MockPeerConnection.kt +++ b/livekit-android-test/src/main/java/io/livekit/android/test/mock/MockPeerConnection.kt @@ -47,20 +47,27 @@ class MockPeerConnection( var localDesc: SessionDescription? = null var remoteDesc: SessionDescription? = null + private val signalStateMachine = SignalStateMachine { newState -> + observer?.onSignalingChange(newState) + } + private val transceivers = mutableListOf() + override fun getLocalDescription(): SessionDescription? = localDesc - override fun setLocalDescription(observer: SdpObserver?, sdp: SessionDescription?) { - if (sdp?.description?.isEmpty() == true) { - observer?.onSetFailure("empty local description") + override fun setLocalDescription(observer: SdpObserver, sdp: SessionDescription) { + if (sdp.description.isEmpty()) { + observer.onSetFailure("empty local description") return } - // https://w3c.github.io/webrtc-pc/#fig-non-normative-signaling-state-transitions-diagram-method-calls-abbreviated - if (signalingState() == SignalingState.STABLE) { - remoteDesc = null + try { + signalStateMachine.handleSetMethod(Location.LOCAL, sdp.type) + } catch (e: IllegalStateException) { + observer.onSetFailure(e.message) } + localDesc = sdp - observer?.onSetSuccess() + observer.onSetSuccess() if (signalingState() == SignalingState.STABLE) { moveToIceConnectionState(IceConnectionState.CONNECTED) @@ -68,18 +75,20 @@ class MockPeerConnection( } override fun getRemoteDescription(): SessionDescription? = remoteDesc - override fun setRemoteDescription(observer: SdpObserver?, sdp: SessionDescription?) { - if (sdp?.description?.isEmpty() == true) { - observer?.onSetFailure("empty remote description") + override fun setRemoteDescription(observer: SdpObserver, sdp: SessionDescription) { + if (sdp.description.isEmpty()) { + observer.onSetFailure("empty remote description") return } - // https://w3c.github.io/webrtc-pc/#fig-non-normative-signaling-state-transitions-diagram-method-calls-abbreviated - if (signalingState() == SignalingState.STABLE) { - localDesc = null + try { + signalStateMachine.handleSetMethod(Location.REMOTE, sdp.type) + } catch (e: IllegalStateException) { + observer.onSetFailure(e.message) } + remoteDesc = sdp - observer?.onSetSuccess() + observer.onSetSuccess() if (signalingState() == SignalingState.STABLE) { moveToIceConnectionState(IceConnectionState.CONNECTED) @@ -211,29 +220,7 @@ class MockPeerConnection( override fun stopRtcEventLog() { } - override fun signalingState(): SignalingState { - if (closed) { - return SignalingState.CLOSED - } - - if ((localDesc?.type == null && remoteDesc?.type == null) || - (localDesc?.type == SessionDescription.Type.OFFER && - remoteDesc?.type == SessionDescription.Type.ANSWER) || - (localDesc?.type == SessionDescription.Type.ANSWER && - remoteDesc?.type == SessionDescription.Type.OFFER) - ) { - return SignalingState.STABLE - } - - if (localDesc?.type == SessionDescription.Type.OFFER && remoteDesc?.type == null) { - return SignalingState.HAVE_LOCAL_OFFER - } - if (remoteDesc?.type == SessionDescription.Type.OFFER && localDesc?.type == null) { - return SignalingState.HAVE_REMOTE_OFFER - } - - throw IllegalStateException("Illegal signalling state? localDesc: $localDesc, remoteDesc: $remoteDesc") - } + override fun signalingState(): SignalingState = signalStateMachine.state private var iceConnectionState = IceConnectionState.NEW set(value) { @@ -312,6 +299,7 @@ class MockPeerConnection( override fun dispose() { iceConnectionState = IceConnectionState.CLOSED closed = true + signalStateMachine.close() transceivers.forEach { t -> t.dispose() } transceivers.clear() @@ -319,3 +307,78 @@ class MockPeerConnection( override fun getNativePeerConnection(): Long = 0L } + +private class SignalStateMachine( + var changeListener: ((PeerConnection.SignalingState) -> Unit)? = null, +) { + var state = PeerConnection.SignalingState.STABLE + set(value) { + val changed = field != value + field = value + if (changed) { + changeListener?.invoke(field) + } + } + + /** + * Throws if would go to invalid state. + * + * Does not handle PRANSWER or ROLLBACK. + * + * State machine as shown here: + * https://w3c.github.io/webrtc-pc/#fig-non-normative-signaling-state-transitions-diagram-method-calls-abbreviated + */ + @Throws(IllegalStateException::class) + fun handleSetMethod(location: Location, type: SessionDescription.Type) { + fun throwException() { + throw IllegalStateException("Illegal set of $location with $type on signal state $state") + } + when (state) { + PeerConnection.SignalingState.STABLE -> { + // Can only accept offers from stable + if (type != SessionDescription.Type.OFFER) { + throwException() + } + state = when (location) { + Location.LOCAL -> PeerConnection.SignalingState.HAVE_LOCAL_OFFER + Location.REMOTE -> PeerConnection.SignalingState.HAVE_REMOTE_OFFER + } + } + + PeerConnection.SignalingState.HAVE_LOCAL_OFFER -> { + if (location == Location.LOCAL && type == SessionDescription.Type.OFFER) { + // legal, does not change state. + } else if (location == Location.REMOTE && type == SessionDescription.Type.ANSWER) { + state = PeerConnection.SignalingState.STABLE + } else { + throwException() + } + } + + PeerConnection.SignalingState.HAVE_REMOTE_OFFER -> { + if (location == Location.REMOTE && type == SessionDescription.Type.OFFER) { + // legal, does not change state. + } else if (location == Location.LOCAL && type == SessionDescription.Type.ANSWER) { + state = PeerConnection.SignalingState.STABLE + } else { + throwException() + } + } + + PeerConnection.SignalingState.HAVE_LOCAL_PRANSWER -> TODO() + PeerConnection.SignalingState.HAVE_REMOTE_PRANSWER -> TODO() + PeerConnection.SignalingState.CLOSED -> { + throw IllegalStateException("Closed") + } + } + } + + fun close() { + state = PeerConnection.SignalingState.CLOSED + } +} + +private enum class Location { + LOCAL, + REMOTE +} diff --git a/livekit-android-test/src/test/java/io/livekit/android/test/mock/MockPeerConnectionTest.kt b/livekit-android-test/src/test/java/io/livekit/android/test/mock/MockPeerConnectionTest.kt new file mode 100644 index 00000000..27027751 --- /dev/null +++ b/livekit-android-test/src/test/java/io/livekit/android/test/mock/MockPeerConnectionTest.kt @@ -0,0 +1,168 @@ +/* + * Copyright 2025 LiveKit, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.livekit.android.test.mock + +import livekit.org.webrtc.PeerConnection +import livekit.org.webrtc.SdpObserver +import livekit.org.webrtc.SessionDescription +import org.junit.Assert.assertEquals +import org.junit.Before +import org.junit.Test +import org.mockito.Mockito.mock +import org.mockito.kotlin.any +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +class MockPeerConnectionTest { + + lateinit var pc: MockPeerConnection + + @Before + fun setup() { + pc = MockPeerConnection(PeerConnection.RTCConfiguration(emptyList()), null) + } + + @Test + fun publisherNegotiation() { + run { + val observer = mock() + pc.setLocalDescription( + observer, + SessionDescription(SessionDescription.Type.OFFER, "local_offer"), + ) + verify(observer, times(1)).onSetSuccess() + assertEquals(PeerConnection.SignalingState.HAVE_LOCAL_OFFER, pc.signalingState()) + } + run { + val observer = mock() + pc.setRemoteDescription( + observer, + SessionDescription(SessionDescription.Type.ANSWER, "remote_answer"), + ) + verify(observer, times(1)).onSetSuccess() + assertEquals(PeerConnection.SignalingState.STABLE, pc.signalingState()) + } + } + + @Test + fun subscriberNegotiation() { + run { + val observer = mock() + pc.setRemoteDescription( + observer, + SessionDescription(SessionDescription.Type.OFFER, "remote_offer"), + ) + verify(observer, times(1)).onSetSuccess() + assertEquals(PeerConnection.SignalingState.HAVE_REMOTE_OFFER, pc.signalingState()) + } + run { + val observer = mock() + pc.setLocalDescription( + observer, + SessionDescription(SessionDescription.Type.ANSWER, "local_answer"), + ) + verify(observer, times(1)).onSetSuccess() + assertEquals(PeerConnection.SignalingState.STABLE, pc.signalingState()) + } + } + + @Test + fun cannotSetAnswerOnStable() { + run { + val observer = mock() + pc.setLocalDescription( + observer, + SessionDescription(SessionDescription.Type.ANSWER, "local_answer"), + ) + verify(observer, times(1)).onSetFailure(any()) + assertEquals(PeerConnection.SignalingState.STABLE, pc.signalingState()) + } + run { + val observer = mock() + pc.setRemoteDescription( + observer, + SessionDescription(SessionDescription.Type.ANSWER, "remote_answer"), + ) + verify(observer, times(1)).onSetFailure(any()) + assertEquals(PeerConnection.SignalingState.STABLE, pc.signalingState()) + } + } + + @Test + fun cannotIllegalSetOnHaveLocalOffer() { + run { + val observer = mock() + pc.setLocalDescription( + observer, + SessionDescription(SessionDescription.Type.OFFER, "local_offer"), + ) + assertEquals(PeerConnection.SignalingState.HAVE_LOCAL_OFFER, pc.signalingState()) + } + + run { + val observer = mock() + pc.setLocalDescription( + observer, + SessionDescription(SessionDescription.Type.ANSWER, "local_answer"), + ) + verify(observer, times(1)).onSetFailure(any()) + assertEquals(PeerConnection.SignalingState.HAVE_LOCAL_OFFER, pc.signalingState()) + } + + run { + val observer = mock() + pc.setRemoteDescription( + observer, + SessionDescription(SessionDescription.Type.OFFER, "remote_offer"), + ) + verify(observer, times(1)).onSetFailure(any()) + assertEquals(PeerConnection.SignalingState.HAVE_LOCAL_OFFER, pc.signalingState()) + } + } + + @Test + fun cannotIllegalSetOnHaveRemoteOffer() { + run { + val observer = mock() + pc.setRemoteDescription( + observer, + SessionDescription(SessionDescription.Type.OFFER, "remote_offer"), + ) + assertEquals(PeerConnection.SignalingState.HAVE_REMOTE_OFFER, pc.signalingState()) + } + + run { + val observer = mock() + pc.setLocalDescription( + observer, + SessionDescription(SessionDescription.Type.OFFER, "local_offer"), + ) + verify(observer, times(1)).onSetFailure(any()) + assertEquals(PeerConnection.SignalingState.HAVE_REMOTE_OFFER, pc.signalingState()) + } + + run { + val observer = mock() + pc.setRemoteDescription( + observer, + SessionDescription(SessionDescription.Type.ANSWER, "remote_offer"), + ) + verify(observer, times(1)).onSetFailure(any()) + assertEquals(PeerConnection.SignalingState.HAVE_REMOTE_OFFER, pc.signalingState()) + } + } +}