diff --git a/.changeset/good-snails-switch.md b/.changeset/good-snails-switch.md new file mode 100644 index 000000000..77c850e9f --- /dev/null +++ b/.changeset/good-snails-switch.md @@ -0,0 +1,5 @@ +--- +"client-sdk-android": minor +--- + +Implement RPC diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/events/RoomEvent.kt b/livekit-android-sdk/src/main/java/io/livekit/android/events/RoomEvent.kt index a76620d25..c0bdde1c3 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/events/RoomEvent.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/events/RoomEvent.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 LiveKit, Inc. + * Copyright 2023-2025 LiveKit, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -273,6 +273,9 @@ enum class DisconnectReason { MIGRATION, SIGNAL_CLOSE, ROOM_CLOSED, + USER_UNAVAILABLE, + USER_REJECTED, + SIP_TRUNK_FAILURE, } /** @@ -290,6 +293,9 @@ fun LivekitModels.DisconnectReason?.convert(): DisconnectReason { LivekitModels.DisconnectReason.MIGRATION -> DisconnectReason.MIGRATION LivekitModels.DisconnectReason.SIGNAL_CLOSE -> DisconnectReason.SIGNAL_CLOSE LivekitModels.DisconnectReason.ROOM_CLOSED -> DisconnectReason.ROOM_CLOSED + LivekitModels.DisconnectReason.USER_UNAVAILABLE -> DisconnectReason.USER_UNAVAILABLE + LivekitModels.DisconnectReason.USER_REJECTED -> DisconnectReason.USER_REJECTED + LivekitModels.DisconnectReason.SIP_TRUNK_FAILURE -> DisconnectReason.SIP_TRUNK_FAILURE LivekitModels.DisconnectReason.UNKNOWN_REASON, LivekitModels.DisconnectReason.UNRECOGNIZED, null, diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/room/RTCEngine.kt b/livekit-android-sdk/src/main/java/io/livekit/android/room/RTCEngine.kt index 6171c25dd..831caf12b 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/room/RTCEngine.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/room/RTCEngine.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 LiveKit, Inc. + * Copyright 2023-2025 LiveKit, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package io.livekit.android.room import android.os.SystemClock import androidx.annotation.VisibleForTesting import com.google.protobuf.ByteString +import com.vdurmont.semver4j.Semver import io.livekit.android.ConnectOptions import io.livekit.android.RoomOptions import io.livekit.android.dagger.InjectionNames @@ -148,6 +149,9 @@ internal constructor( private var lastRoomOptions: RoomOptions? = null private var participantSid: String? = null + internal val serverVersion: Semver? + get() = client.serverVersion + private val publisherObserver = PublisherTransportObserver(this, client) private val subscriberObserver = SubscriberTransportObserver(this, client) @@ -777,6 +781,7 @@ internal constructor( fun onLocalTrackUnpublished(trackUnpublished: LivekitRtc.TrackUnpublishedResponse) fun onTranscriptionReceived(transcription: LivekitModels.Transcription) fun onLocalTrackSubscribed(trackSubscribed: LivekitRtc.TrackSubscribed) + fun onRpcPacketReceived(dp: LivekitModels.DataPacket) } companion object { @@ -792,7 +797,7 @@ internal constructor( */ @VisibleForTesting const val LOSSY_DATA_CHANNEL_LABEL = "_lossy" - internal const val MAX_DATA_PACKET_SIZE = 15000 + internal const val MAX_DATA_PACKET_SIZE = 15360 // 15 KB private const val MAX_RECONNECT_RETRIES = 10 private const val MAX_RECONNECT_TIMEOUT = 60 * 1000 private const val MAX_ICE_CONNECT_TIMEOUT_MS = 20000 @@ -1040,13 +1045,21 @@ internal constructor( LivekitModels.DataPacket.ValueCase.RPC_ACK, LivekitModels.DataPacket.ValueCase.RPC_RESPONSE, -> { - // TODO + listener?.onRpcPacketReceived(dp) } LivekitModels.DataPacket.ValueCase.VALUE_NOT_SET, null, -> { LKLog.v { "invalid value for data packet" } } + + LivekitModels.DataPacket.ValueCase.STREAM_HEADER -> { + // TODO + } + + LivekitModels.DataPacket.ValueCase.STREAM_CHUNK -> { + // TODO + } } } diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/room/Room.kt b/livekit-android-sdk/src/main/java/io/livekit/android/room/Room.kt index d7bd703f8..33095be8b 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/room/Room.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/room/Room.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 LiveKit, Inc. + * Copyright 2023-2025 LiveKit, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -648,6 +648,8 @@ constructor( mutableRemoteParticipants = newParticipants eventBus.postEvent(RoomEvent.ParticipantDisconnected(this, removedParticipant), coroutineScope) + + localParticipant.handleParticipantDisconnect(identity) } fun getParticipantBySid(sid: String): Participant? { @@ -1195,6 +1197,10 @@ constructor( publication?.onTranscriptionReceived(event) } + override fun onRpcPacketReceived(dp: LivekitModels.DataPacket) { + localParticipant.handleDataPacket(dp) + } + /** * @suppress */ diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/room/SignalClient.kt b/livekit-android-sdk/src/main/java/io/livekit/android/room/SignalClient.kt index 02d753fba..768f422f9 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/room/SignalClient.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/room/SignalClient.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 LiveKit, Inc. + * Copyright 2023-2025 LiveKit, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -84,7 +84,7 @@ constructor( private var currentWs: WebSocket? = null private var isReconnecting: Boolean = false var listener: Listener? = null - private var serverVersion: Semver? = null + internal var serverVersion: Semver? = null private var lastUrl: String? = null private var lastOptions: ConnectOptions? = null private var lastRoomOptions: RoomOptions? = null @@ -841,6 +841,7 @@ constructor( lastUrl = null lastOptions = null lastRoomOptions = null + serverVersion = null } interface Listener { diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/room/participant/LocalParticipant.kt b/livekit-android-sdk/src/main/java/io/livekit/android/room/participant/LocalParticipant.kt index 43b5c9dad..8a9f3b9e0 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/room/participant/LocalParticipant.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/room/participant/LocalParticipant.kt @@ -21,6 +21,7 @@ import android.content.Context import android.content.Intent import androidx.annotation.VisibleForTesting import com.google.protobuf.ByteString +import com.vdurmont.semver4j.Semver import dagger.assisted.Assisted import dagger.assisted.AssistedFactory import dagger.assisted.AssistedInject @@ -47,15 +48,21 @@ import io.livekit.android.room.track.VideoCaptureParameter import io.livekit.android.room.track.VideoCodec import io.livekit.android.room.track.VideoEncoding import io.livekit.android.room.util.EncodingUtils +import io.livekit.android.rpc.RpcError import io.livekit.android.util.LKLog +import io.livekit.android.util.byteLength import io.livekit.android.util.flow import io.livekit.android.webrtc.sortVideoCodecPreferences import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.Job +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.delay import kotlinx.coroutines.launch +import kotlinx.coroutines.suspendCancellableCoroutine import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import livekit.LivekitModels +import livekit.LivekitModels.DataPacket import livekit.LivekitRtc import livekit.LivekitRtc.AddTrackRequest import livekit.LivekitRtc.SimulcastCodec @@ -67,9 +74,15 @@ import livekit.org.webrtc.RtpTransceiver.RtpTransceiverInit import livekit.org.webrtc.SurfaceTextureHelper import livekit.org.webrtc.VideoCapturer import livekit.org.webrtc.VideoProcessor +import java.util.Collections +import java.util.UUID import javax.inject.Named +import kotlin.coroutines.resume import kotlin.math.max import kotlin.math.min +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds class LocalParticipant @AssistedInject @@ -105,6 +118,10 @@ internal constructor( private val jobs = mutableMapOf() + private val rpcHandlers = Collections.synchronizedMap(mutableMapOf()) // methodName to handler + private val pendingAcks = Collections.synchronizedMap(mutableMapOf()) // requestId to pending ack + private val pendingResponses = Collections.synchronizedMap(mutableMapOf()) // requestId to pending response + // For ensuring that only one caller can execute setTrackEnabled at a time. // Without it, there's a potential to create multiple of the same source, // Camera has deadlock issues with multiple CameraCapturers trying to activate/stop. @@ -714,8 +731,8 @@ internal constructor( } val kind = when (reliability) { - DataPublishReliability.RELIABLE -> LivekitModels.DataPacket.Kind.RELIABLE - DataPublishReliability.LOSSY -> LivekitModels.DataPacket.Kind.LOSSY + DataPublishReliability.RELIABLE -> DataPacket.Kind.RELIABLE + DataPublishReliability.LOSSY -> DataPacket.Kind.LOSSY } val packetBuilder = LivekitModels.UserPacket.newBuilder().apply { payload = ByteString.copyFrom(data) @@ -727,7 +744,7 @@ internal constructor( addAllDestinationIdentities(identities.map { it.value }) } } - val dataPacket = LivekitModels.DataPacket.newBuilder() + val dataPacket = DataPacket.newBuilder() .setUser(packetBuilder) .setKind(kind) .build() @@ -741,9 +758,8 @@ internal constructor( * SipDTMF message using the provided code and digit, then encapsulates it * in a DataPacket before sending it via the engine. * - * Parameters: - * - code: an integer representing the DTMF signal code - * - digit: the string representing the DTMF digit (e.g., "1", "#", "*") + * @param code an integer representing the DTMF signal code + * @param digit the string representing the DTMF digit (e.g., "1", "#", "*") */ @Suppress("unused") @@ -763,6 +779,375 @@ internal constructor( engine.sendData(dataPacket) } + /** + * Establishes the participant as a receiver for calls of the specified RPC method. + * Will overwrite any existing callback for the same method. + * + * Example: + * ```kt + * room.localParticipant.registerRpcMethod("greet") { (requestId, callerIdentity, payload, responseTimeout) -> + * Log.i("TAG", "Received greeting from ${callerIdentity}: ${payload}") + * + * // Return a string + * "Hello, ${callerIdentity}!" + * } + * ``` + * + * The handler receives an [RpcInvocationData] with the following parameters: + * - `requestId`: A unique identifier for this RPC request + * - `callerIdentity`: The identity of the RemoteParticipant who initiated the RPC call + * - `payload`: The data sent by the caller (as a string) + * - `responseTimeout`: The maximum time available to return a response + * + * The handler should return a string. + * If unable to respond within [RpcInvocationData.responseTimeout], the request will result in an error on the caller's side. + * + * You may throw errors of type [RpcError] with a string `message` in the handler, + * and they will be received on the caller's side with the message intact. + * Other errors thrown in your handler will not be transmitted as-is, and will instead arrive to the caller as `1500` ("Application Error"). + * + * @param method The name of the indicated RPC method + * @param handler Will be invoked when an RPC request for this method is received + * @see RpcHandler + * @see RpcInvocationData + * @see performRpc + */ + @Suppress("RedundantSuspendModifier") + suspend fun registerRpcMethod( + method: String, + handler: RpcHandler, + ) { + this.rpcHandlers[method] = handler + } + + /** + * Unregisters a previously registered RPC method. + * + * @param method The name of the RPC method to unregister + */ + fun unregisterRpcMethod( + method: String, + ) { + this.rpcHandlers.remove(method) + } + + internal fun handleDataPacket(packet: DataPacket) { + when { + packet.hasRpcRequest() -> { + val rpcRequest = packet.rpcRequest + scope.launch { + handleIncomingRpcRequest( + callerIdentity = Identity(packet.participantIdentity), + requestId = rpcRequest.id, + method = rpcRequest.method, + payload = rpcRequest.payload, + responseTimeout = rpcRequest.responseTimeoutMs.toUInt().toLong().milliseconds, + version = rpcRequest.version, + ) + } + } + + packet.hasRpcResponse() -> { + val rpcResponse = packet.rpcResponse + var payload: String? = null + var error: RpcError? = null + + if (rpcResponse.hasPayload()) { + payload = rpcResponse.payload + } else if (rpcResponse.hasError()) { + error = RpcError.fromProto(rpcResponse.error) + } + handleIncomingRpcResponse( + requestId = rpcResponse.requestId, + payload = payload, + error = error, + ) + } + + packet.hasRpcAck() -> { + val rpcAck = packet.rpcAck + handleIncomingRpcAck(rpcAck.requestId) + } + } + } + + /** + * Initiate an RPC call to a remote participant + * @param destinationIdentity The identity of the destination participant. + * @param method The method name to call. + * @param payload The payload to pass to the method. + * @param responseTimeout Timeout for receiving a response after initial connection. + * Defaults to 10000. Max value of UInt.MAX_VALUE milliseconds. + * @return The response payload. + * @throws RpcError on failure. Details in [RpcError.message]. + */ + suspend fun performRpc( + destinationIdentity: Identity, + method: String, + payload: String, + responseTimeout: Duration = 10.seconds, + ): String = coroutineScope { + val maxRoundTripLatency = 2.seconds + + if (payload.byteLength() > RTCEngine.MAX_DATA_PACKET_SIZE) { + throw RpcError.BuiltinRpcError.REQUEST_PAYLOAD_TOO_LARGE.create() + } + + val serverVersion = engine.serverVersion + ?: throw RpcError.BuiltinRpcError.SEND_FAILED.create(data = "Not connected.") + + if (serverVersion < Semver("1.8.0")) { + throw RpcError.BuiltinRpcError.UNSUPPORTED_SERVER.create() + } + + val requestId = UUID.randomUUID().toString() + + publishRpcRequest( + destinationIdentity = destinationIdentity, + requestId = requestId, + method = method, + payload = payload, + responseTimeout = responseTimeout - maxRoundTripLatency, + ) + + val responsePayload = suspendCancellableCoroutine { continuation -> + var ackTimeoutJob: Job? = null + var responseTimeoutJob: Job? = null + + fun cleanup() { + ackTimeoutJob?.cancel() + responseTimeoutJob?.cancel() + pendingAcks.remove(requestId) + pendingResponses.remove(requestId) + } + + continuation.invokeOnCancellation { cleanup() } + + ackTimeoutJob = launch { + delay(maxRoundTripLatency) + val receivedAck = pendingAcks.remove(requestId) == null + if (!receivedAck) { + pendingResponses.remove(requestId) + continuation.cancel(RpcError.BuiltinRpcError.CONNECTION_TIMEOUT.create()) + } + } + pendingAcks[requestId] = PendingRpcAck( + participantIdentity = destinationIdentity, + onResolve = { ackTimeoutJob.cancel() }, + ) + + responseTimeoutJob = launch { + delay(responseTimeout) + val receivedResponse = pendingResponses.remove(requestId) == null + if (!receivedResponse) { + continuation.cancel(RpcError.BuiltinRpcError.RESPONSE_TIMEOUT.create()) + } + } + + pendingResponses[requestId] = PendingRpcResponse( + participantIdentity = destinationIdentity, + onResolve = { payload, error -> + if (pendingAcks.containsKey(requestId)) { + LKLog.i { "RPC response received before ack, id: $requestId" } + } + cleanup() + + if (error != null) { + continuation.cancel(error) + } else { + continuation.resume(payload ?: "") + } + }, + ) + } + return@coroutineScope responsePayload + } + + private suspend fun publishRpcRequest( + destinationIdentity: Identity, + requestId: String, + method: String, + payload: String, + responseTimeout: Duration = 10.seconds, + ) { + if (payload.byteLength() > RTCEngine.MAX_DATA_PACKET_SIZE) { + throw IllegalArgumentException("cannot publish data larger than " + RTCEngine.MAX_DATA_PACKET_SIZE) + } + + val dataPacket = with(DataPacket.newBuilder()) { + addDestinationIdentities(destinationIdentity.value) + kind = DataPacket.Kind.RELIABLE + rpcRequest = with(LivekitModels.RpcRequest.newBuilder()) { + this.id = requestId + this.method = method + this.payload = payload + this.responseTimeoutMs = responseTimeout.inWholeMilliseconds.toUInt().toInt() + build() + } + build() + } + + engine.sendData(dataPacket) + } + + private suspend fun publishRpcResponse( + destinationIdentity: Identity, + requestId: String, + payload: String?, + error: RpcError?, + ) { + if (payload.byteLength() > RTCEngine.MAX_DATA_PACKET_SIZE) { + throw IllegalArgumentException("cannot publish data larger than " + RTCEngine.MAX_DATA_PACKET_SIZE) + } + + val dataPacket = with(DataPacket.newBuilder()) { + addDestinationIdentities(destinationIdentity.value) + kind = DataPacket.Kind.RELIABLE + rpcResponse = with(LivekitModels.RpcResponse.newBuilder()) { + this.requestId = requestId + if (error != null) { + this.error = error.toProto() + } else { + this.payload = payload ?: "" + } + build() + } + build() + } + + engine.sendData(dataPacket) + } + + private suspend fun publishRpcAck( + destinationIdentity: Identity, + requestId: String, + ) { + val dataPacket = with(DataPacket.newBuilder()) { + addDestinationIdentities(destinationIdentity.value) + kind = DataPacket.Kind.RELIABLE + rpcAck = with(LivekitModels.RpcAck.newBuilder()) { + this.requestId = requestId + build() + } + build() + } + + engine.sendData(dataPacket) + } + + private fun handleIncomingRpcAck(requestId: String) { + val handler = this.pendingAcks.remove(requestId) + if (handler != null) { + handler.onResolve() + } else { + LKLog.e { "Ack received for unexpected RPC request, id = $requestId" } + } + } + + private fun handleIncomingRpcResponse( + requestId: String, + payload: String?, + error: RpcError?, + ) { + val handler = this.pendingResponses.remove(requestId) + if (handler != null) { + handler.onResolve(payload, error) + } else { + LKLog.e { "Response received for unexpected RPC request, id = $requestId" } + } + } + + private suspend fun handleIncomingRpcRequest( + callerIdentity: Identity, + requestId: String, + method: String, + payload: String, + responseTimeout: Duration, + version: Int, + ) { + publishRpcAck(callerIdentity, requestId) + + if (version != 1) { + publishRpcResponse( + destinationIdentity = callerIdentity, + requestId = requestId, + payload = null, + error = RpcError.BuiltinRpcError.UNSUPPORTED_VERSION.create(), + ) + return + } + + val handler = this.rpcHandlers[method] + + if (handler == null) { + publishRpcResponse( + destinationIdentity = callerIdentity, + requestId = requestId, + payload = null, + error = RpcError.BuiltinRpcError.UNSUPPORTED_METHOD.create(), + ) + return + } + + var responseError: RpcError? = null + var responsePayload: String? = null + + try { + val response = handler.invoke( + RpcInvocationData( + requestId = requestId, + callerIdentity = callerIdentity, + payload = payload, + responseTimeout = responseTimeout, + ), + ) + + if (response.byteLength() > RTCEngine.MAX_DATA_PACKET_SIZE) { + responseError = RpcError.BuiltinRpcError.RESPONSE_PAYLOAD_TOO_LARGE.create() + LKLog.w { "RPC Response payload too large for $method" } + } else { + responsePayload = response + } + } catch (e: Exception) { + if (e is RpcError) { + responseError = e + } else { + LKLog.w(e) { "Uncaught error returned by RPC handler for $method. Returning APPLICATION_ERROR instead." } + responseError = RpcError.BuiltinRpcError.APPLICATION_ERROR.create() + } + } + + publishRpcResponse( + destinationIdentity = callerIdentity, + requestId = requestId, + payload = responsePayload, + error = responseError, + ) + } + + internal fun handleParticipantDisconnect(identity: Identity) { + synchronized(pendingAcks) { + val acksIterator = pendingAcks.iterator() + while (acksIterator.hasNext()) { + val (_, ack) = acksIterator.next() + if (ack.participantIdentity == identity) { + acksIterator.remove() + } + } + } + + synchronized(pendingResponses) { + val responsesIterator = pendingResponses.iterator() + while (responsesIterator.hasNext()) { + val (_, response) = responsesIterator.next() + if (response.participantIdentity == identity) { + responsesIterator.remove() + response.onResolve(null, RpcError.BuiltinRpcError.RECIPIENT_DISCONNECTED.create()) + } + } + } + } + /** * @suppress */ @@ -1232,3 +1617,42 @@ internal fun VideoTrackPublishOptions.hasBackupCodec(): Boolean { private val backupCodecs = listOf(VideoCodec.VP8.codecName, VideoCodec.H264.codecName) private fun isBackupCodec(codecName: String) = backupCodecs.contains(codecName) + +/** + * A handler that processes an RPC request and returns a string + * that will be sent back to the requester. + * + * Throwing an [RpcError] will send the error back to the requester. + * + * @see [LocalParticipant.registerRpcMethod] + */ +typealias RpcHandler = suspend (RpcInvocationData) -> String + +data class RpcInvocationData( + /** + * A unique identifier for this RPC request + */ + val requestId: String, + /** + * The identity of the RemoteParticipant who initiated the RPC call + */ + val callerIdentity: Participant.Identity, + /** + * The data sent by the caller (as a string) + */ + val payload: String, + /** + * The maximum time available to return a response + */ + val responseTimeout: Duration, +) + +private data class PendingRpcAck( + val onResolve: () -> Unit, + val participantIdentity: Participant.Identity, +) + +private data class PendingRpcResponse( + val onResolve: (payload: String?, error: RpcError?) -> Unit, + val participantIdentity: Participant.Identity, +) diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/rpc/RpcError.kt b/livekit-android-sdk/src/main/java/io/livekit/android/rpc/RpcError.kt new file mode 100644 index 000000000..62492f153 --- /dev/null +++ b/livekit-android-sdk/src/main/java/io/livekit/android/rpc/RpcError.kt @@ -0,0 +1,89 @@ +/* + * Copyright 2025 LiveKit, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.livekit.android.rpc + +import io.livekit.android.room.RTCEngine +import io.livekit.android.util.truncateBytes +import livekit.LivekitModels + +/** + * Specialized error handling for RPC methods. + * + * Instances of this type, when thrown in a RPC method handler, will have their [message] + * serialized and sent across the wire. The sender will receive an equivalent error on the other side. + * + * Built-in types are included but developers may use any message string, with a max length of 256 bytes. + */ +data class RpcError( + /** + * The error code of the RPC call. Error codes 1001-1999 are reserved for built-in errors. + * + * See [RpcError.BuiltinRpcError] for built-in error information. + */ + val code: Int, + + /** + * A message to include. Strings over 256 bytes will be truncated. + */ + override val message: String, + /** + * An optional data payload. Must be smaller than 15KB in size, or else will be truncated. + */ + val data: String = "", +) : Exception(message) { + + enum class BuiltinRpcError(val code: Int, val message: String) { + APPLICATION_ERROR(1500, "Application error in method handler"), + CONNECTION_TIMEOUT(1501, "Connection timeout"), + RESPONSE_TIMEOUT(1502, "Response timeout"), + RECIPIENT_DISCONNECTED(1503, "Recipient disconnected"), + RESPONSE_PAYLOAD_TOO_LARGE(1504, "Response payload too large"), + SEND_FAILED(1505, "Failed to send"), + + UNSUPPORTED_METHOD(1400, "Method not supported at destination"), + RECIPIENT_NOT_FOUND(1401, "Recipient not found"), + REQUEST_PAYLOAD_TOO_LARGE(1402, "Request payload too large"), + UNSUPPORTED_SERVER(1403, "RPC not supported by server"), + UNSUPPORTED_VERSION(1404, "Unsupported RPC version"), + ; + + fun create(data: String = ""): RpcError { + return RpcError(code, message, data) + } + } + + companion object { + const val MAX_MESSAGE_BYTES = 256 + + fun fromProto(proto: LivekitModels.RpcError): RpcError { + return RpcError( + code = proto.code, + message = (proto.message ?: "").truncateBytes(MAX_MESSAGE_BYTES), + data = proto.data.truncateBytes(RTCEngine.MAX_DATA_PACKET_SIZE), + ) + } + } + + fun toProto(): LivekitModels.RpcError { + return with(LivekitModels.RpcError.newBuilder()) { + this.code = this@RpcError.code + this.message = this@RpcError.message + this.data = this@RpcError.data + build() + } + } +} diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/util/StringByteUtils.kt b/livekit-android-sdk/src/main/java/io/livekit/android/util/StringByteUtils.kt new file mode 100644 index 000000000..778a46b53 --- /dev/null +++ b/livekit-android-sdk/src/main/java/io/livekit/android/util/StringByteUtils.kt @@ -0,0 +1,47 @@ +/* + * Copyright 2025 LiveKit, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.livekit.android.util + +import okio.ByteString.Companion.encode + +internal fun String?.byteLength(): Int { + if (this == null) { + return 0 + } + return this.encode(Charsets.UTF_8).size +} + +internal fun String.truncateBytes(maxBytes: Int): String { + if (this.byteLength() <= maxBytes) { + return this + } + + var low = 0 + var high = length + + // Binary search for string that fits. + while (low < high) { + val mid = (low + high + 1) / 2 + if (this.substring(0, mid).byteLength() <= maxBytes) { + low = mid + } else { + high = mid - 1 + } + } + + return substring(0, low) +} diff --git a/livekit-android-test/src/main/java/io/livekit/android/test/MockE2ETest.kt b/livekit-android-test/src/main/java/io/livekit/android/test/MockE2ETest.kt index 97252fb76..c85312ee1 100644 --- a/livekit-android-test/src/main/java/io/livekit/android/test/MockE2ETest.kt +++ b/livekit-android-test/src/main/java/io/livekit/android/test/MockE2ETest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 LiveKit, Inc. + * Copyright 2023-2025 LiveKit, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -69,7 +69,7 @@ abstract class MockE2ETest : BaseTest() { room.release() } - suspend fun connect(joinResponse: LivekitRtc.SignalResponse = TestData.JOIN) { + open suspend fun connect(joinResponse: LivekitRtc.SignalResponse = TestData.JOIN) { connectSignal(joinResponse) connectPeerConnection() } diff --git a/livekit-android-test/src/main/java/io/livekit/android/test/mock/MockDataChannel.kt b/livekit-android-test/src/main/java/io/livekit/android/test/mock/MockDataChannel.kt index 1899c8cce..2fecc8726 100644 --- a/livekit-android-test/src/main/java/io/livekit/android/test/mock/MockDataChannel.kt +++ b/livekit-android-test/src/main/java/io/livekit/android/test/mock/MockDataChannel.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 LiveKit, Inc. + * Copyright 2023-2025 LiveKit, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import livekit.org.webrtc.DataChannel class MockDataChannel(private val label: String?) : DataChannel(1L) { var observer: Observer? = null - var sentBuffers = mutableListOf() + var sentBuffers = mutableListOf() override fun registerObserver(observer: Observer?) { this.observer = observer } @@ -46,7 +46,7 @@ class MockDataChannel(private val label: String?) : DataChannel(1L) { return 0 } - override fun send(buffer: Buffer?): Boolean { + override fun send(buffer: Buffer): Boolean { sentBuffers.add(buffer) return true } @@ -56,4 +56,8 @@ class MockDataChannel(private val label: String?) : DataChannel(1L) { override fun dispose() { } + + fun simulateBufferReceived(buffer: Buffer) { + observer?.onMessage(buffer) + } } diff --git a/livekit-android-test/src/main/java/io/livekit/android/test/mock/TestData.kt b/livekit-android-test/src/main/java/io/livekit/android/test/mock/TestData.kt index bdeaafd42..5869969ff 100644 --- a/livekit-android-test/src/main/java/io/livekit/android/test/mock/TestData.kt +++ b/livekit-android-test/src/main/java/io/livekit/android/test/mock/TestData.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 LiveKit, Inc. + * Copyright 2023-2025 LiveKit, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package io.livekit.android.test.mock import livekit.LivekitModels import livekit.LivekitRtc +import java.util.UUID object TestData { @@ -110,7 +111,7 @@ object TestData { build() }, ) - serverVersion = "0.15.2" + serverVersion = "1.8.0" build() } build() @@ -327,4 +328,16 @@ object TestData { } build() } + val DATA_PACKET_RPC_REQUEST = with(LivekitModels.DataPacket.newBuilder()) { + participantIdentity = REMOTE_PARTICIPANT.identity + rpcRequest = with(LivekitModels.RpcRequest.newBuilder()) { + id = UUID.randomUUID().toString() + method = "hello" + payload = "hello world" + responseTimeoutMs = 10000 + version = 1 + build() + } + build() + } } diff --git a/livekit-android-test/src/test/java/io/livekit/android/rpc/RpcMockE2ETest.kt b/livekit-android-test/src/test/java/io/livekit/android/rpc/RpcMockE2ETest.kt new file mode 100644 index 000000000..446f75ab7 --- /dev/null +++ b/livekit-android-test/src/test/java/io/livekit/android/rpc/RpcMockE2ETest.kt @@ -0,0 +1,377 @@ +/* + * Copyright 2023-2025 LiveKit, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.livekit.android.room.participant + +import com.google.protobuf.ByteString +import io.livekit.android.room.RTCEngine +import io.livekit.android.rpc.RpcError +import io.livekit.android.test.MockE2ETest +import io.livekit.android.test.mock.MockDataChannel +import io.livekit.android.test.mock.MockPeerConnection +import io.livekit.android.test.mock.TestData +import io.livekit.android.test.mock.TestData.REMOTE_PARTICIPANT +import io.livekit.android.test.util.toDataChannelBuffer +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.async +import kotlinx.coroutines.launch +import livekit.LivekitModels +import livekit.LivekitRtc +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import kotlin.time.Duration.Companion.milliseconds + +@ExperimentalCoroutinesApi +@RunWith(RobolectricTestRunner::class) +class RpcMockE2ETest : MockE2ETest() { + + lateinit var pubDataChannel: MockDataChannel + lateinit var subDataChannel: MockDataChannel + + companion object { + val ERROR = RpcError( + 1, + "This is an error message.", + "This is an error payload.", + ) + } + + override suspend fun connect(joinResponse: LivekitRtc.SignalResponse) { + super.connect(joinResponse) + + val pubPeerConnection = component.rtcEngine().getPublisherPeerConnection() as MockPeerConnection + pubDataChannel = pubPeerConnection.dataChannels[RTCEngine.RELIABLE_DATA_CHANNEL_LABEL] as MockDataChannel + + val subPeerConnection = component.rtcEngine().getSubscriberPeerConnection() as MockPeerConnection + subDataChannel = MockDataChannel(RTCEngine.RELIABLE_DATA_CHANNEL_LABEL) + subPeerConnection.observer?.onDataChannel(subDataChannel) + } + + private fun createAck(requestId: String) = + with(LivekitModels.DataPacket.newBuilder()) { + participantIdentity = REMOTE_PARTICIPANT.identity + rpcAck = with(LivekitModels.RpcAck.newBuilder()) { + this.requestId = requestId + build() + } + build() + }.toDataChannelBuffer() + + private fun createResponse(requestId: String, payload: String? = null, error: RpcError? = null) = with(LivekitModels.DataPacket.newBuilder()) { + participantIdentity = REMOTE_PARTICIPANT.identity + rpcResponse = with(LivekitModels.RpcResponse.newBuilder()) { + this.requestId = requestId + if (error != null) { + this.error = error.toProto() + } else if (payload != null) { + this.payload = payload + } + + build() + } + build() + }.toDataChannelBuffer() + + @Test + fun handleRpcRequest() = runTest { + connect() + + var methodCalled = false + room.localParticipant.registerRpcMethod("hello") { + methodCalled = true + "bye" + } + subDataChannel.simulateBufferReceived(TestData.DATA_PACKET_RPC_REQUEST.toDataChannelBuffer()) + assertTrue(methodCalled) + + coroutineRule.dispatcher.scheduler.advanceUntilIdle() + + // Check that ack and response were sent + val buffers = pubDataChannel.sentBuffers + assertEquals(2, buffers.size) + + val ackBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data)) + val responseBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[1].data)) + + assertTrue(ackBuffer.hasRpcAck()) + assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, ackBuffer.rpcAck.requestId) + + assertTrue(responseBuffer.hasRpcResponse()) + assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, responseBuffer.rpcResponse.requestId) + assertEquals("bye", responseBuffer.rpcResponse.payload) + } + + @Test + fun handleRpcRequestWithError() = runTest { + connect() + + var methodCalled = false + room.localParticipant.registerRpcMethod("hello") { + methodCalled = true + throw ERROR + } + subDataChannel.simulateBufferReceived(TestData.DATA_PACKET_RPC_REQUEST.toDataChannelBuffer()) + assertTrue(methodCalled) + + coroutineRule.dispatcher.scheduler.advanceUntilIdle() + + // Check that ack and response were sent + val buffers = pubDataChannel.sentBuffers + assertEquals(2, buffers.size) + + val ackBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data)) + val responseBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[1].data)) + + assertTrue(ackBuffer.hasRpcAck()) + assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, ackBuffer.rpcAck.requestId) + + assertTrue(responseBuffer.hasRpcResponse()) + assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, responseBuffer.rpcResponse.requestId) + assertEquals(ERROR, RpcError.fromProto(responseBuffer.rpcResponse.error)) + } + + @Test + fun handleRpcRequestWithNoHandler() = runTest { + connect() + + subDataChannel.simulateBufferReceived(TestData.DATA_PACKET_RPC_REQUEST.toDataChannelBuffer()) + + coroutineRule.dispatcher.scheduler.advanceUntilIdle() + + // Check that ack and response were sent + val buffers = pubDataChannel.sentBuffers + assertEquals(2, buffers.size) + + val ackBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data)) + val responseBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[1].data)) + + assertTrue(ackBuffer.hasRpcAck()) + assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, ackBuffer.rpcAck.requestId) + + assertTrue(responseBuffer.hasRpcResponse()) + assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, responseBuffer.rpcResponse.requestId) + assertEquals(RpcError.BuiltinRpcError.UNSUPPORTED_METHOD.create(), RpcError.fromProto(responseBuffer.rpcResponse.error)) + } + + @Test + fun performRpc() = runTest { + connect() + + val rpcJob = async { + room.localParticipant.performRpc( + destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity), + method = "hello", + payload = "hello world", + ) + } + + // Check that request was sent + val buffers = pubDataChannel.sentBuffers + assertEquals(1, buffers.size) + + val requestBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data)) + + assertTrue(requestBuffer.hasRpcRequest()) + assertEquals("hello", requestBuffer.rpcRequest.method) + assertEquals("hello world", requestBuffer.rpcRequest.payload) + + val requestId = requestBuffer.rpcRequest.id + + // receive ack and response + subDataChannel.simulateBufferReceived(createAck(requestId)) + subDataChannel.simulateBufferReceived(createResponse(requestId, payload = "bye")) + + coroutineRule.dispatcher.scheduler.advanceUntilIdle() + val response = rpcJob.await() + + assertEquals("bye", response) + } + + @Test + fun performRpcWithError() = runTest { + connect() + + val rpcJob = async { + var expectedError: Exception? = null + try { + room.localParticipant.performRpc( + destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity), + method = "hello", + payload = "hello world", + ) + } catch (e: Exception) { + expectedError = e + } + return@async expectedError + } + + val buffers = pubDataChannel.sentBuffers + val requestBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data)) + val requestId = requestBuffer.rpcRequest.id + + // receive ack and response + subDataChannel.simulateBufferReceived(createAck(requestId)) + subDataChannel.simulateBufferReceived(createResponse(requestId, error = ERROR)) + + coroutineRule.dispatcher.scheduler.advanceUntilIdle() + val receivedError = rpcJob.await() + + assertEquals(ERROR, receivedError) + } + + @Test + fun performRpcWithParticipantDisconnected() = runTest { + connect() + simulateMessageFromServer(TestData.PARTICIPANT_JOIN) + + val rpcJob = async { + var expectedError: Exception? = null + try { + room.localParticipant.performRpc( + destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity), + method = "hello", + payload = "hello world", + ) + } catch (e: Exception) { + expectedError = e + } + return@async expectedError + } + + simulateMessageFromServer(TestData.PARTICIPANT_DISCONNECT) + + coroutineRule.dispatcher.scheduler.advanceUntilIdle() + val error = rpcJob.await() + + assertEquals(RpcError.BuiltinRpcError.RECIPIENT_DISCONNECTED.create(), error) + } + + @Test + fun performRpcWithConnectionTimeoutError() = runTest { + connect() + + val rpcJob = async { + var expectedError: Exception? = null + try { + room.localParticipant.performRpc( + destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity), + method = "hello", + payload = "hello world", + ) + } catch (e: Exception) { + expectedError = e + } + return@async expectedError + } + + coroutineRule.dispatcher.scheduler.advanceTimeBy(3000) + + val error = rpcJob.await() + + assertEquals(RpcError.BuiltinRpcError.CONNECTION_TIMEOUT.create(), error) + } + + @Test + fun performRpcWithResponseTimeoutError() = runTest { + connect() + + val rpcJob = async { + var expectedError: Exception? = null + try { + room.localParticipant.performRpc( + destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity), + method = "hello", + payload = "hello world", + ) + } catch (e: Exception) { + expectedError = e + } + return@async expectedError + } + + val buffers = pubDataChannel.sentBuffers + val requestBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data)) + val requestId = requestBuffer.rpcRequest.id + + // receive ack only + subDataChannel.simulateBufferReceived(createAck(requestId)) + + coroutineRule.dispatcher.scheduler.advanceTimeBy(15000) + + val error = rpcJob.await() + + assertEquals(RpcError.BuiltinRpcError.RESPONSE_TIMEOUT.create(), error) + } + + @Test + fun uintMaxValueVerification() = runTest { + assertEquals(4_294_967_295L, UInt.MAX_VALUE.toLong()) + } + + /** + * Protobuf handles UInt32 as Java signed integers. + * This test verifies whether our conversion is properly sent over the wire. + */ + @Test + fun performRpcProtoUIntVerification() = runTest { + connect() + val rpcJob = launch { + room.localParticipant.performRpc( + destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity), + method = "hello", + payload = "hello world", + responseTimeout = UInt.MAX_VALUE.toLong().milliseconds, + ) + } + + val buffers = pubDataChannel.sentBuffers + val requestBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data)) + + val expectedResponseTimeout = UInt.MAX_VALUE - 2000u // 2000 comes from maxRoundTripLatency + val responseTimeout = requestBuffer.rpcRequest.responseTimeoutMs.toUInt() + assertEquals(expectedResponseTimeout, responseTimeout) + rpcJob.cancel() + } + + /** + * Protobuf handles UInt32 as Java signed integers. + * This test verifies whether our conversion is properly sent over the wire. + */ + @Test + fun handleRpcProtoUIntVerification() = runTest { + connect() + + var methodCalled = false + room.localParticipant.registerRpcMethod("hello") { invocationData -> + assertEquals(4_294_967_295L, invocationData.responseTimeout.inWholeMilliseconds) + methodCalled = true + "bye" + } + subDataChannel.simulateBufferReceived( + with(TestData.DATA_PACKET_RPC_REQUEST.toBuilder()) { + rpcRequest = with(rpcRequest.toBuilder()) { + responseTimeoutMs = UInt.MAX_VALUE.toInt() + build() + } + build() + }.toDataChannelBuffer(), + ) + assertTrue(methodCalled) + } +} diff --git a/protocol b/protocol index a601adc5e..9e8d1e37c 160000 --- a/protocol +++ b/protocol @@ -1 +1 @@ -Subproject commit a601adc5e9027820857a6d445b32a868b19d4184 +Subproject commit 9e8d1e37c5eb4434424bc16c657c83e7dc63bc2a