Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: wire cells sdk integration (#WPB-15743) #3281

Open
wants to merge 10 commits into
base: epic/wire-cells
Choose a base branch
from
1 change: 1 addition & 0 deletions cells/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/build
86 changes: 86 additions & 0 deletions cells/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
plugins {
id(libs.plugins.android.library.get().pluginId)
id(libs.plugins.kotlin.multiplatform.get().pluginId)
alias(libs.plugins.ksp)
id(libs.plugins.kalium.library.get().pluginId)
}

kaliumLibrary {
multiplatform { enableJs.set(false) }
}

kotlin {
explicitApi()
sourceSets {
val commonMain by getting {
dependencies {
implementation(project(":common"))
implementation(project(":network"))
implementation(project(":data"))
implementation(project(":util"))
implementation(project(":persistence"))
implementation(libs.coroutines.core)
implementation(libs.ktor.authClient)
implementation(libs.okio.core)
implementation(libs.benAsherUUID)
implementation(libs.wire.cells.sdk)
}
}
val commonTest by getting {
dependencies {
// coroutines
implementation(libs.coroutines.test)
implementation(libs.turbine)
// ktor test
implementation(libs.ktor.mock)
// mocks
implementation(libs.mockative.runtime)
implementation(libs.okio.test)
}
}

fun org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet.addCommonKotlinJvmSourceDir() {
kotlin.srcDir("src/commonJvmAndroid/kotlin")
}

val jvmMain by getting {
addCommonKotlinJvmSourceDir()
dependencies {
implementation(libs.ktor.okHttp)
implementation(awssdk.services.s3)
}
}
val androidMain by getting {
addCommonKotlinJvmSourceDir()
dependencies {
implementation(libs.ktor.okHttp)
implementation(awssdk.services.s3)
}
}
}
}

dependencies {
configurations
.filter { it.name.startsWith("ksp") && it.name.contains("Test") }
.forEach {
add(it.name, libs.mockative.processor)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Wire
* Copyright (C) 2025 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.cells.data

import aws.smithy.kotlin.runtime.client.ProtocolRequestInterceptorContext
import aws.smithy.kotlin.runtime.client.ProtocolResponseInterceptorContext
import aws.smithy.kotlin.runtime.http.HttpBody
import aws.smithy.kotlin.runtime.http.interceptors.HttpInterceptor
import aws.smithy.kotlin.runtime.http.request.HttpRequest
import aws.smithy.kotlin.runtime.http.request.toBuilder
import aws.smithy.kotlin.runtime.http.response.HttpResponse
import aws.smithy.kotlin.runtime.io.SdkBuffer
import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
import aws.smithy.kotlin.runtime.io.SdkSource
import aws.smithy.kotlin.runtime.io.readAll

internal open class AwsProgressListenerInterceptor(
private val progressListener: (Long) -> Unit
) : HttpInterceptor {
fun convertBodyWithProgressUpdates(httpBody: HttpBody): HttpBody {
return when (httpBody) {
is HttpBody.ChannelContent -> {
SdkByteReadChannelWithProgressUpdates(
httpBody,
progressListener
)
}
is HttpBody.SourceContent -> {
SourceContentWithProgressUpdates(
httpBody,
progressListener
)
}
is HttpBody.Bytes -> {
httpBody
}
is HttpBody.Empty -> {
httpBody
}
}
}

internal class SourceContentWithProgressUpdates(
private val sourceContent: SourceContent,
private val progressListener: (Long) -> Unit
) : HttpBody.SourceContent() {
private val delegate = sourceContent.readFrom()
private var uploaded = 0L
override val contentLength: Long?
get() = sourceContent.contentLength

override fun readFrom(): SdkSource {
return object : SdkSource {
override fun close() {
delegate.close()
}

override fun read(sink: SdkBuffer, limit: Long): Long {
return delegate.read(sink, limit).also {
if (it > 0) {
uploaded += it
progressListener(uploaded)
}
}
}
}
}
}

internal class SdkByteReadChannelWithProgressUpdates(
private val httpBody: ChannelContent,
private val progressListener: (Long) -> Unit
) : HttpBody.ChannelContent() {
val delegate = httpBody.readFrom()
private var uploaded = 0L
override val contentLength: Long?
get() = httpBody.contentLength
override fun readFrom(): SdkByteReadChannel {
return object : SdkByteReadChannel by delegate {
override val availableForRead: Int
get() = delegate.availableForRead

override val isClosedForRead: Boolean
get() = delegate.isClosedForRead

override val isClosedForWrite: Boolean
get() = delegate.isClosedForWrite

override fun cancel(cause: Throwable?): Boolean {
return delegate.cancel(cause)
}

override suspend fun read(sink: SdkBuffer, limit: Long): Long {
return delegate.readAll(sink).also {
if (it > 0) {
uploaded += it
progressListener(uploaded)
}
}
}
}
}
}

internal class DownloadProgressListenerInterceptor(
progressListener: (Long) -> Unit
) : AwsProgressListenerInterceptor(progressListener) {
override suspend fun modifyBeforeDeserialization(
context: ProtocolResponseInterceptorContext<Any, HttpRequest, HttpResponse>
): HttpResponse {
val body = convertBodyWithProgressUpdates(context.protocolResponse.body)
return HttpResponse(context.protocolResponse.status, context.protocolResponse.headers, body)
}
}

internal class UploadProgressListenerInterceptor(
progressListener: (Long) -> Unit
) : AwsProgressListenerInterceptor(progressListener) {
override suspend fun modifyBeforeTransmit(
context: ProtocolRequestInterceptorContext<Any, HttpRequest>
): HttpRequest {
val builder = context.protocolRequest.toBuilder()
builder.body = convertBodyWithProgressUpdates(builder.body)
return builder.build()
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Wire
* Copyright (C) 2025 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.cells.data

import aws.sdk.kotlin.runtime.auth.credentials.StaticCredentialsProvider
import aws.sdk.kotlin.services.s3.S3Client
import aws.sdk.kotlin.services.s3.completeMultipartUpload
import aws.sdk.kotlin.services.s3.createMultipartUpload
import aws.sdk.kotlin.services.s3.model.CompletedMultipartUpload
import aws.sdk.kotlin.services.s3.model.CompletedPart
import aws.sdk.kotlin.services.s3.putObject
import aws.sdk.kotlin.services.s3.uploadPart
import aws.sdk.kotlin.services.s3.withConfig
import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials
import aws.smithy.kotlin.runtime.content.ByteStream
import aws.smithy.kotlin.runtime.content.asByteStream
import aws.smithy.kotlin.runtime.net.url.Url
import com.wire.kalium.cells.data.model.CellNodeDTO
import com.wire.kalium.cells.domain.model.CellsCredentials
import okhttp3.internal.http2.Header
import okio.Path
import java.io.RandomAccessFile
import java.nio.ByteBuffer

internal actual fun cellsAwsClient(credentials: CellsCredentials): CellsAwsClient = CellsAwsClientJvm(credentials)

private class CellsAwsClientJvm(
private val credentials: CellsCredentials
) : CellsAwsClient {

private companion object {
const val DEFAULT_REGION = "us-east-1"
const val DEFAULT_BUCKET_NAME = "io"
const val MAX_REGULAR_UPLOAD_SIZE = 100 * 1024 * 1024L
const val MULTIPART_CHUNK_SIZE = 10 * 1024 * 1024
}

private val s3Client: S3Client by lazy { buildS3Client() }

private fun buildS3Client() = with(credentials) {
S3Client {
region = DEFAULT_REGION
enableAwsChunked = false
credentialsProvider = StaticCredentialsProvider(
Credentials(
accessKeyId = accessToken,
secretAccessKey = gatewaySecret,
)
)
endpointUrl = Url.parse(serverUrl)
}
}

override suspend fun upload(path: Path, node: CellNodeDTO, onProgressUpdate: (Long) -> Unit) {
val length = path.toFile().length()
if (length > MAX_REGULAR_UPLOAD_SIZE) {
uploadMultipart(path, node, onProgressUpdate)
} else {
uploadRegular(path, node, onProgressUpdate)
}
}

private suspend fun uploadRegular(path: Path, node: CellNodeDTO, onProgressUpdate: (Long) -> Unit) {
withS3Client(uploadProgressListener = onProgressUpdate) {
putObject {
bucket = DEFAULT_BUCKET_NAME
key = node.path
metadata = node.createDraftNodeMetaData()
body = path.toFile().asByteStream()
}
}
}

private suspend fun uploadMultipart(path: Path, node: CellNodeDTO, onProgressUpdate: (Long) -> Unit) {
val buffer = ByteBuffer.allocate(MULTIPART_CHUNK_SIZE)
var number = 1
val completed = mutableListOf<CompletedPart>()
withS3Client {
val requestId = createMultipartUpload {
bucket = DEFAULT_BUCKET_NAME
key = node.path
metadata = node.createDraftNodeMetaData()
}.uploadId
RandomAccessFile(path.toFile(), "r").use { file ->
val fileSize = file.length()
var position = 0L
while (position < fileSize) {
file.seek(position)
val bytesRead = file.channel.read(buffer)
onProgressUpdate(position + bytesRead)
buffer.flip()
val partData = ByteArray(bytesRead)
buffer.get(partData, 0, bytesRead)
val response = uploadPart {
bucket = DEFAULT_BUCKET_NAME
key = node.path
uploadId = requestId
partNumber = number
contentLength = bytesRead.toLong()
body = ByteStream.fromBytes(partData)
}
completed.add(
CompletedPart {
partNumber = number
eTag = response.eTag
}
)
buffer.clear()
position += bytesRead
number++
}
}
completeMultipartUpload {
bucket = DEFAULT_BUCKET_NAME
key = node.path
uploadId = requestId
multipartUpload = CompletedMultipartUpload {
parts = completed
}
}
}
}

private suspend fun <T> withS3Client(
uploadProgressListener: ((Long) -> Unit)? = null,
downloadProgressListener: ((Long) -> Unit)? = null,
block: suspend S3Client.() -> T,
): T =
s3Client.withConfig {
if (uploadProgressListener != null) {
Header.TARGET_PATH
interceptors.add(AwsProgressListenerInterceptor.UploadProgressListenerInterceptor(uploadProgressListener))
}
if (downloadProgressListener != null) {
interceptors.add(AwsProgressListenerInterceptor.DownloadProgressListenerInterceptor(downloadProgressListener))
}
}.use {
block(it)
}
}

private fun CellNodeDTO.createDraftNodeMetaData() = mapOf(
MetadataHeaders.DRAFT_MODE to "true",
MetadataHeaders.CREATE_RESOURCE_UUID to uuid,
MetadataHeaders.CREATE_VERSION_ID to versionId,
)
Loading
Loading