Skip to content

Commit 32b678d

Browse files
authored
fix: handle CRT native connection resources more carefully (#198)
1 parent 1b9b566 commit 32b678d

File tree

10 files changed

+236
-239
lines changed

10 files changed

+236
-239
lines changed

aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/auth/signing/AwsSignerNative.kt

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,7 @@ package aws.sdk.kotlin.crt.auth.signing
88
import aws.sdk.kotlin.crt.*
99
import aws.sdk.kotlin.crt.auth.credentials.Credentials
1010
import aws.sdk.kotlin.crt.http.*
11-
import aws.sdk.kotlin.crt.util.asAwsByteCursor
12-
import aws.sdk.kotlin.crt.util.initFromCursor
13-
import aws.sdk.kotlin.crt.util.toAwsString
14-
import aws.sdk.kotlin.crt.util.toKString
15-
import aws.sdk.kotlin.crt.util.use
11+
import aws.sdk.kotlin.crt.util.*
1612
import kotlinx.cinterop.*
1713
import kotlinx.coroutines.channels.Channel
1814
import kotlinx.coroutines.runBlocking
@@ -223,15 +219,10 @@ private fun AwsSigningConfig.toNativeSigningConfig(): CPointer<aws_signing_confi
223219
private typealias ShouldSignHeaderFunction = (String) -> Boolean
224220
private fun nativeShouldSignHeaderFn(headerName: CPointer<aws_byte_cursor>?, userData: COpaquePointer?): Boolean {
225221
checkNotNull(headerName) { "aws_should_sign_header_fn expected non-null header name" }
226-
if (userData == null) {
227-
return true
228-
}
229-
230-
userData.asStableRef<ShouldSignHeaderFunction>().use {
231-
val kShouldSignHeaderFn = it.get()
222+
return userData?.withDereferenced<ShouldSignHeaderFunction, _>(dispose = true) { kShouldSignHeaderFn ->
232223
val kHeaderName = headerName.pointed.toKString()
233-
return kShouldSignHeaderFn(kHeaderName)
234-
}
224+
kShouldSignHeaderFn(kHeaderName)
225+
} ?: error("Expected non-null userData")
235226
}
236227

237228
/**
@@ -243,17 +234,17 @@ private fun signCallback(signingResult: CPointer<aws_signing_result>?, errorCode
243234
checkNotNull(signingResult) { "signing callback received null aws_signing_result" }
244235
checkNotNull(userData) { "signing callback received null user data" }
245236

246-
val (pinnedRequestToSign, callbackChannel) = userData
247-
.asStableRef<Pair<Pinned<CPointer<cnames.structs.aws_http_message>>, Channel<ByteArray>>>()
248-
.get()
237+
userData.withDereferenced<Pair<Pinned<CPointer<cnames.structs.aws_http_message>>, Channel<ByteArray>>> { pair ->
238+
val (pinnedRequestToSign, callbackChannel) = pair
249239

250-
val requestToSign = pinnedRequestToSign.get()
240+
val requestToSign = pinnedRequestToSign.get()
251241

252-
awsAssertOpSuccess(aws_apply_signing_result_to_http_request(requestToSign, Allocator.Default.allocator, signingResult)) {
253-
"aws_apply_signing_result_to_http_request"
254-
}
242+
awsAssertOpSuccess(aws_apply_signing_result_to_http_request(requestToSign, Allocator.Default.allocator, signingResult)) {
243+
"aws_apply_signing_result_to_http_request"
244+
}
255245

256-
runBlocking { callbackChannel.send(signingResult.getSignature()) }
246+
runBlocking { callbackChannel.send(signingResult.getSignature()) }
247+
}
257248
}
258249

259250
/**
@@ -264,8 +255,9 @@ private fun signChunkCallback(signingResult: CPointer<aws_signing_result>?, erro
264255
checkNotNull(signingResult) { "signing callback received null aws_signing_result" }
265256
checkNotNull(userData) { "signing callback received null user data" }
266257

267-
val callbackChannel = userData.asStableRef<Channel<ByteArray>>().get()
268-
runBlocking { callbackChannel.send(signingResult.getSignature()) }
258+
userData.withDereferenced<Channel<ByteArray>> { callbackChannel ->
259+
runBlocking { callbackChannel.send(signingResult.getSignature()) }
260+
}
269261
}
270262

271263
private fun Credentials.toNativeCredentials(): CPointer<cnames.structs.aws_credentials>? = aws_credentials_new_from_string(

aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpClientConnectionManagerNative.kt

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -183,13 +183,10 @@ private fun SocketDomain.toNativeSocketDomain() = when (this) {
183183
}
184184

185185
private fun onShutdownComplete(userdata: COpaquePointer?) {
186-
if (userdata == null) return
187-
val notify = userdata.asStableRef<ShutdownChannel>()
188-
with(notify.get()) {
189-
trySend(Unit)
190-
close()
186+
userdata?.withDereferenced<ShutdownChannel>(dispose = true) { notify ->
187+
notify.trySend(Unit)
188+
notify.close()
191189
}
192-
notify.dispose()
193190
}
194191

195192
private data class HttpConnectionAcquisitionRequest(
@@ -202,20 +199,16 @@ private fun onConnectionAcquired(
202199
errCode: Int,
203200
userdata: COpaquePointer?,
204201
) {
205-
if (userdata == null) return
206-
val stableRef = userdata.asStableRef<HttpConnectionAcquisitionRequest>()
207-
val request = stableRef.get()
208-
209-
when {
210-
errCode != AWS_OP_SUCCESS -> request.cont.resumeWithException(HttpException(errCode))
211-
conn == null -> request.cont.resumeWithException(
212-
CrtRuntimeException("acquireConnection(): http connection null", ec = errCode),
213-
)
214-
else -> {
215-
val kconn = HttpClientConnectionNative(request.manager, conn)
216-
request.cont.resume(kconn)
202+
userdata?.withDereferenced<HttpConnectionAcquisitionRequest>(dispose = true) { request ->
203+
when {
204+
errCode != AWS_OP_SUCCESS -> request.cont.resumeWithException(HttpException(errCode))
205+
conn == null -> request.cont.resumeWithException(
206+
CrtRuntimeException("acquireConnection(): http connection null", ec = errCode),
207+
)
208+
else -> {
209+
val kconn = HttpClientConnectionNative(request.manager, conn)
210+
request.cont.resume(kconn)
211+
}
217212
}
218213
}
219-
220-
stableRef.dispose()
221214
}

aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpClientConnectionNative.kt

Lines changed: 69 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,7 @@ package aws.sdk.kotlin.crt.http
77
import aws.sdk.kotlin.crt.*
88
import aws.sdk.kotlin.crt.io.Buffer
99
import aws.sdk.kotlin.crt.io.ByteCursorBuffer
10-
import aws.sdk.kotlin.crt.util.asAwsByteCursor
11-
import aws.sdk.kotlin.crt.util.initFromCursor
12-
import aws.sdk.kotlin.crt.util.toKString
13-
import aws.sdk.kotlin.crt.util.use
14-
import aws.sdk.kotlin.crt.util.withAwsByteCursor
10+
import aws.sdk.kotlin.crt.util.*
1511
import kotlinx.atomicfu.atomic
1612
import kotlinx.cinterop.*
1713
import libcrt.*
@@ -87,105 +83,100 @@ private class HttpStreamContext(
8783
val nativeReq: CPointer<cnames.structs.aws_http_message>,
8884
)
8985

86+
private fun callbackError(): Int = aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
87+
9088
private fun onResponseHeaders(
9189
nativeStream: CPointer<cnames.structs.aws_http_stream>?,
9290
blockType: aws_http_header_block,
9391
headerArray: CPointer<aws_http_header>?,
9492
numHeaders: size_t,
9593
userdata: COpaquePointer?,
96-
): Int {
97-
val ctxStableRef = userdata?.asStableRef<HttpStreamContext>() ?: return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
98-
ctxStableRef.use {
99-
val ctx = it.get()
100-
val stream = ctx.stream ?: return AWS_OP_ERR
101-
102-
val hdrCnt = numHeaders.toInt()
103-
val headers: List<HttpHeader>? = if (hdrCnt > 0 && headerArray != null) {
104-
val kheaders = mutableListOf<HttpHeader>()
105-
for (i in 0 until hdrCnt) {
106-
val nativeHdr = headerArray[i]
107-
val hdr = HttpHeader(nativeHdr.name.toKString(), nativeHdr.value.toKString())
108-
kheaders.add(hdr)
94+
): Int =
95+
userdata?.withDereferenced<HttpStreamContext, _> { ctx ->
96+
ctx.stream?.let { stream ->
97+
val hdrCnt = numHeaders.toInt()
98+
val headers: List<HttpHeader>? = if (hdrCnt > 0 && headerArray != null) {
99+
val kheaders = mutableListOf<HttpHeader>()
100+
for (i in 0 until hdrCnt) {
101+
val nativeHdr = headerArray[i]
102+
val hdr = HttpHeader(nativeHdr.name.toKString(), nativeHdr.value.toKString())
103+
kheaders.add(hdr)
104+
}
105+
kheaders
106+
} else {
107+
null
109108
}
110-
kheaders
111-
} else {
112-
null
113-
}
114109

115-
try {
116-
ctx.handler.onResponseHeaders(stream, stream.responseStatusCode, blockType.value.toInt(), headers)
117-
} catch (ex: Exception) {
118-
log(LogLevel.Error, "onResponseHeaders: $ex")
119-
return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
110+
try {
111+
ctx.handler.onResponseHeaders(stream, stream.responseStatusCode, blockType.value.toInt(), headers)
112+
AWS_OP_SUCCESS
113+
} catch (ex: Exception) {
114+
log(LogLevel.Error, "onResponseHeaders: $ex")
115+
null
116+
}
120117
}
121-
122-
return AWS_OP_SUCCESS
123-
}
124-
}
118+
} ?: callbackError()
125119

126120
private fun onResponseHeaderBlockDone(
127121
nativeStream: CPointer<cnames.structs.aws_http_stream>?,
128122
blockType: aws_http_header_block,
129123
userdata: COpaquePointer?,
130-
): Int {
131-
val ctx = userdata?.asStableRef<HttpStreamContext>()?.get() ?: return AWS_OP_ERR
132-
val stream = ctx.stream ?: return AWS_OP_ERR
133-
134-
try {
135-
ctx.handler.onResponseHeadersDone(stream, blockType.value.toInt())
136-
} catch (ex: Exception) {
137-
log(LogLevel.Error, "onResponseHeaderBlockDone: $ex")
138-
return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
139-
}
140-
141-
return AWS_OP_SUCCESS
142-
}
124+
): Int =
125+
userdata?.withDereferenced<HttpStreamContext, _> { ctx ->
126+
ctx.stream?.let { stream ->
127+
try {
128+
ctx.handler.onResponseHeadersDone(stream, blockType.value.toInt())
129+
AWS_OP_SUCCESS
130+
} catch (ex: Exception) {
131+
log(LogLevel.Error, "onResponseHeaderBlockDone: $ex")
132+
null
133+
}
134+
}
135+
} ?: callbackError()
143136

144137
private fun onIncomingBody(
145138
nativeStream: CPointer<cnames.structs.aws_http_stream>?,
146139
data: CPointer<aws_byte_cursor>?,
147140
userdata: COpaquePointer?,
148-
): Int {
149-
val ctx = userdata?.asStableRef<HttpStreamContext>()?.get() ?: return AWS_OP_ERR
150-
val stream = ctx.stream ?: return AWS_OP_ERR
151-
152-
try {
153-
val body = if (data != null) ByteCursorBuffer(data) else Buffer.Empty
154-
val windowIncrement = ctx.handler.onResponseBody(stream, body)
155-
if (windowIncrement < 0) {
156-
return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
157-
}
158-
159-
if (windowIncrement > 0) {
160-
aws_http_stream_update_window(nativeStream, windowIncrement.convert())
141+
): Int =
142+
userdata?.withDereferenced<HttpStreamContext, _> { ctx ->
143+
ctx.stream?.let { stream ->
144+
try {
145+
val body = if (data != null) ByteCursorBuffer(data) else Buffer.Empty
146+
val windowIncrement = ctx.handler.onResponseBody(stream, body)
147+
148+
if (windowIncrement < 0) {
149+
null
150+
} else {
151+
if (windowIncrement > 0) {
152+
aws_http_stream_update_window(nativeStream, windowIncrement.convert())
153+
}
154+
AWS_OP_SUCCESS
155+
}
156+
} catch (ex: Exception) {
157+
log(LogLevel.Error, "onIncomingBody: $ex")
158+
null
159+
}
161160
}
162-
} catch (ex: Exception) {
163-
log(LogLevel.Error, "onIncomingBody: $ex")
164-
return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
165-
}
166-
167-
return AWS_OP_SUCCESS
168-
}
161+
} ?: callbackError()
169162

170163
private fun onStreamComplete(
171164
nativeStream: CPointer<cnames.structs.aws_http_stream>?,
172165
errorCode: Int,
173166
userdata: COpaquePointer?,
174167
) {
175-
val stableRef = userdata?.asStableRef<HttpStreamContext>() ?: return
176-
val ctx = stableRef.get()
177-
val stream = ctx.stream ?: return
178-
179-
try {
180-
ctx.handler.onResponseComplete(stream, errorCode)
181-
} catch (ex: Exception) {
182-
log(LogLevel.Error, "onStreamComplete: $ex")
183-
// close connection if callback throws an exception
184-
aws_http_connection_close(aws_http_stream_get_connection(nativeStream))
185-
} finally {
186-
// cleanup stream resources
187-
stableRef.dispose()
188-
aws_http_message_destroy(ctx.nativeReq)
168+
userdata?.withDereferenced<HttpStreamContext>(dispose = true) { ctx ->
169+
try {
170+
val stream = ctx.stream ?: return
171+
ctx.handler.onResponseComplete(stream, errorCode)
172+
} catch (ex: Exception) {
173+
log(LogLevel.Error, "onStreamComplete: $ex")
174+
// close connection if callback throws an exception
175+
aws_http_connection_close(aws_http_stream_get_connection(nativeStream))
176+
} finally {
177+
// cleanup request object
178+
aws_http_message_release(ctx.nativeReq)
179+
}
189180
}
190181
}
191182

aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpStreamNative.kt

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import aws.sdk.kotlin.crt.NativeHandle
1010
import aws.sdk.kotlin.crt.awsAssertOpSuccess
1111
import aws.sdk.kotlin.crt.util.asAwsByteCursor
1212
import aws.sdk.kotlin.crt.util.use
13+
import aws.sdk.kotlin.crt.util.withDereferenced
1314
import kotlinx.atomicfu.atomic
1415
import kotlinx.cinterop.*
1516
import libcrt.*
@@ -67,12 +68,13 @@ internal class HttpStreamNative(
6768
throw CrtRuntimeException("aws_input_stream_new_from_cursor()")
6869
}
6970

70-
StableRef.create(WriteChunkRequest(cont, byteBuf, stream)).use { req ->
71+
val req = WriteChunkRequest(cont, byteBuf, stream)
72+
StableRef.create(req).use { stableRef ->
7173
val chunkOpts = cValue<aws_http1_chunk_options> {
7274
chunk_data_size = chunkData.size.convert()
7375
chunk_data = stream
7476
on_complete = staticCFunction(::onWriteChunkComplete)
75-
user_data = req.asCPointer()
77+
user_data = stableRef.asCPointer()
7678
}
7779
awsAssertOpSuccess(
7880
aws_http1_stream_write_chunk(ptr, chunkOpts),
@@ -113,19 +115,18 @@ private fun onWriteChunkComplete(
113115
userData: COpaquePointer?,
114116
) {
115117
if (stream == null) return
116-
val stableRef = userData?.asStableRef<WriteChunkRequest>() ?: return
117-
val req = stableRef.get()
118-
when {
119-
errCode != AWS_OP_SUCCESS -> req.cont.resumeWithException(HttpException(errCode))
120-
else -> req.cont.resume(Unit)
118+
userData?.withDereferenced<WriteChunkRequest> { req ->
119+
checkNotNull(req) { "Received null request in onWriteChunkComplete" }
120+
when {
121+
errCode != AWS_OP_SUCCESS -> req.cont.resumeWithException(HttpException(errCode))
122+
else -> req.cont.resume(Unit)
123+
}
124+
cleanupWriteChunkCbData(req)
121125
}
122-
cleanupWriteChunkCbData(stableRef)
123126
}
124127

125-
private fun cleanupWriteChunkCbData(stableRef: StableRef<WriteChunkRequest>) {
126-
val req = stableRef.get()
128+
private fun cleanupWriteChunkCbData(req: WriteChunkRequest) {
127129
aws_input_stream_destroy(req.inputStream)
128130
aws_byte_buf_clean_up(req.chunkData)
129131
Allocator.Default.free(req.inputStream)
130-
stableRef.dispose()
131132
}

0 commit comments

Comments
 (0)