@@ -210,19 +210,27 @@ public open class GoogleLLMClient @JvmOverloads constructor(
         response.candidates.firstOrNull()?.let { candidate ->
             candidate.content?.parts?.forEachIndexed { index, part ->
                 when (part) {
+                    is GooglePart.Text -> {
+                        if (part.thought == true) {
+                            emitReasoningDelta(
+                                id = part.thoughtSignature,
+                                text = part.text,
+                                index = index,
+                            )
+                        } else {
+                            emitTextDelta(part.text, index)
+                        }
+                    }
+
                     is GooglePart.FunctionCall -> {
                         emitToolCallDelta(
                             id = part.functionCall.id,
                             name = part.functionCall.name,
                             args = part.functionCall.args?.toString() ?: "{}",
-                            index = index
+                            index = index,
                         )
                     }
 
-                    is GooglePart.Text -> {
-                        emitTextDelta(part.text, index)
-                    }
-
                     else -> Unit
                 }
             }
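This hunk reorders the `when` so that Gemini "thought" parts surface as reasoning frames instead of plain text deltas. A minimal consumer sketch, assuming an already-configured `GoogleLLMClient`; the routing `when` below is illustrative, not code from this PR:

```kotlin
import ai.koog.prompt.dsl.Prompt
import ai.koog.prompt.message.Message
import ai.koog.prompt.message.RequestMetaInfo
import ai.koog.prompt.streaming.StreamFrame
import kotlinx.coroutines.flow.collect

// Sketch: stream a prompt and split reasoning output from answer text.
suspend fun printStream(client: GoogleLLMClient) {
    val prompt = Prompt(messages = listOf(Message.User("Hi", RequestMetaInfo.Empty)), id = "demo")
    client.executeStreaming(prompt = prompt, model = GoogleModels.Gemini2_5Pro).collect { frame ->
        when (frame) {
            // GooglePart.Text with thought == true now arrives here, keyed by its thoughtSignature.
            is StreamFrame.ReasoningDelta -> println("[reasoning] ${frame.text}")
            // Non-thought text keeps its previous behavior.
            is StreamFrame.TextDelta -> print(frame.text)
            is StreamFrame.End -> println("\n[done: ${frame.finishReason}]")
            else -> Unit // ReasoningComplete, TextComplete, tool calls, ...
        }
    }
}
```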
@@ -3,30 +3,40 @@ package ai.koog.prompt.executor.clients.google
 import ai.koog.agents.core.tools.ToolDescriptor
 import ai.koog.agents.core.tools.ToolParameterDescriptor
 import ai.koog.agents.core.tools.ToolParameterType
+import ai.koog.http.client.KoogHttpClient
 import ai.koog.prompt.dsl.Prompt
+import ai.koog.prompt.executor.clients.google.models.GoogleCandidate
+import ai.koog.prompt.executor.clients.google.models.GoogleContent
 import ai.koog.prompt.executor.clients.google.models.GoogleData
 import ai.koog.prompt.executor.clients.google.models.GoogleFunctionCallingMode
 import ai.koog.prompt.executor.clients.google.models.GooglePart
 import ai.koog.prompt.executor.clients.google.models.GoogleRequest
+import ai.koog.prompt.executor.clients.google.models.GoogleResponse
 import ai.koog.prompt.executor.clients.google.models.GoogleThinkingConfig
 import ai.koog.prompt.message.AttachmentContent
 import ai.koog.prompt.message.ContentPart
 import ai.koog.prompt.message.Message
 import ai.koog.prompt.message.RequestMetaInfo
 import ai.koog.prompt.message.ResponseMetaInfo
 import ai.koog.prompt.params.LLMParams
+import ai.koog.prompt.streaming.StreamFrame
 import io.kotest.matchers.collections.shouldContain
 import io.kotest.matchers.collections.shouldHaveSize
 import io.kotest.matchers.shouldBe
 import io.kotest.matchers.shouldNotBe
 import io.kotest.matchers.types.shouldBeInstanceOf
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.emptyFlow
+import kotlinx.coroutines.flow.flowOf
+import kotlinx.coroutines.flow.toList
 import kotlinx.coroutines.test.runTest
 import kotlinx.serialization.json.JsonObject
 import kotlinx.serialization.json.JsonPrimitive
 import kotlinx.serialization.json.buildJsonObject
 import kotlinx.serialization.json.jsonArray
 import kotlinx.serialization.json.jsonObject
 import kotlinx.serialization.json.jsonPrimitive
+import kotlin.reflect.KClass
 import kotlin.test.Test
 
 class GoogleLLMClientTest {
@@ -519,14 +529,98 @@ class GoogleLLMClientTest {
         reasoning.encrypted shouldBe "thought-sig"
     }
 
+    @Test
+    fun `executeStreaming emits reasoning delta for thought text parts`() = runTest {
+        val model = GoogleModels.Gemini2_5Pro
+        val thoughtSignature = "thought-signature-dummy"
+        val finalAnswer = "final answer"
+        val thoughtText = "internal thinking"
+        val finishReason = "STOP"
+
+        val transport = object : KoogHttpClient {
+            override val clientName: String = "GoogleStreamingTestClient"
+
+            override suspend fun <R : Any> get(
+                path: String,
+                responseType: KClass<R>,
+                parameters: Map<String, String>,
+            ): R = error("GET is not expected in this test")
+
+            override suspend fun <T : Any, R : Any> post(
+                path: String,
+                request: T,
+                requestBodyType: KClass<T>,
+                responseType: KClass<R>,
+                parameters: Map<String, String>,
+            ): R = error("POST is not expected in this test")
+
+            override fun <T : Any, R : Any, O : Any> sse(
+                path: String,
+                request: T,
+                requestBodyType: KClass<T>,
+                dataFilter: (String?) -> Boolean,
+                decodeStreamingResponse: (String) -> R,
+                processStreamingChunk: (R) -> O?,
+                parameters: Map<String, String>,
+            ): Flow<O> {
+                path shouldBe "v1beta/models/${model.id}:streamGenerateContent"
+                request.shouldBeInstanceOf<GoogleRequest>()
+
+                val response = GoogleResponse(
+                    candidates = listOf(
+                        GoogleCandidate(
+                            content = GoogleContent(
+                                role = "model",
+                                parts = listOf(
+                                    GooglePart.Text(
+                                        text = thoughtText,
+                                        thought = true,
+                                        thoughtSignature = thoughtSignature
+                                    ),
+                                    GooglePart.Text(text = finalAnswer)
+                                )
+                            ),
+                            finishReason = finishReason,
+                            index = 0
+                        )
+                    )
+                )
+
+                val chunk = processStreamingChunk(response as R)
+                return if (chunk != null) flowOf(chunk) else emptyFlow()
+            }
+
+            override fun close(): Unit = Unit
+        }
+
+        val client = GoogleLLMClient(httpClient = transport)
+
+        val frames = client.executeStreaming(
+            prompt = Prompt(messages = listOf(Message.User("Hi", RequestMetaInfo.Empty)), id = "id"),
+            model = model,
+        ).toList()
+
+        frames shouldBe listOf(
+            StreamFrame.ReasoningDelta(id = thoughtSignature, text = thoughtText, index = 0),
+            StreamFrame.ReasoningComplete(id = thoughtSignature, text = listOf(thoughtText), index = 0),
+            StreamFrame.TextDelta(finalAnswer, 1),
+            StreamFrame.TextComplete(finalAnswer, 1),
+            StreamFrame.End(finishReason, ResponseMetaInfo.Empty),
+        )
+    }
+
     @Test
     fun `createGoogleRequest includes Reasoning as Text part with thought=true`() {
         val client = GoogleLLMClient(apiKey = "test")
         val request = client.createGoogleRequest(
             Prompt(
                 messages = listOf(
                     Message.User("query", RequestMetaInfo.Empty),
-                    Message.Reasoning(content = "Previous thought", encrypted = "prev-sig", metaInfo = ResponseMetaInfo.Empty)
+                    Message.Reasoning(
+                        content = "Previous thought",
+                        encrypted = "prev-sig",
+                        metaInfo = ResponseMetaInfo.Empty
+                    )
                 ),
                 id = "id"
             ),
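The `createGoogleRequest` test above pins the outbound direction as well: a `Message.Reasoning` recorded from an earlier turn is serialized back to Gemini as a thought `Text` part. A sketch of replaying captured reasoning in a follow-up prompt; the `encrypted` to `thoughtSignature` round trip is inferred from the parsing assertion (`reasoning.encrypted shouldBe "thought-sig"`), and the values are placeholders:

```kotlin
// Illustrative follow-up prompt; it would be passed to createGoogleRequest
// as in the test above (whose remaining arguments are truncated in this diff).
val followUp = Prompt(
    messages = listOf(
        Message.User("query", RequestMetaInfo.Empty),
        Message.Reasoning(
            content = "Previous thought",   // becomes GooglePart.Text(..., thought = true)
            encrypted = "prev-sig",         // presumably carried as the part's thoughtSignature
            metaInfo = ResponseMetaInfo.Empty,
        ),
    ),
    id = "follow-up",
)
```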
@@ -2250,6 +2250,12 @@ internal sealed interface OpenAIStreamEvent {
         val sequenceNumber: Int,
     ) : OpenAIStreamEvent
 
+    @Serializable
+    @SerialName("keepalive")
+    class ResponseKeepalive(
+        val sequenceNumber: Int,
+    ) : OpenAIStreamEvent
+
     @Serializable
     class LogProbWithTop(val logprob: Double, val token: String, val topLogprobs: List<LogProb>)
 
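`ResponseKeepalive` gives the Responses-API stream decoder a representation for `keepalive` heartbeats, which would otherwise fail polymorphic deserialization. A hedged sketch of the assumed wire shape and handling; the JSON field casing and the `toFrame` helper are assumptions, not the client's actual code:

```kotlin
// Assumed wire shape: "type" matches @SerialName("keepalive"); whether the
// counter serializes as "sequence_number" or "sequenceNumber" depends on the
// Json configuration (the serialization test below runs against two of them).
val sampleEvent = """{"type":"keepalive","sequence_number":48}"""

// Hypothetical frame mapping: a keepalive carries no content, so returning
// null keeps it out of the emitted flow, mirroring the streaming test below,
// where the keepalive chunk produces no StreamFrame.
fun toFrame(event: OpenAIStreamEvent): StreamFrame? = when (event) {
    is OpenAIStreamEvent.ResponseKeepalive -> null // heartbeat only
    else -> error("other event types are handled elsewhere in the real client")
}
```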
@@ -1,10 +1,24 @@
 package ai.koog.prompt.executor.clients.openai
 
+import ai.koog.http.client.KoogHttpClient
+import ai.koog.prompt.dsl.Prompt
 import ai.koog.prompt.dsl.prompt
+import ai.koog.prompt.executor.clients.openai.models.Item
+import ai.koog.prompt.executor.clients.openai.models.OpenAIInputStatus
+import ai.koog.prompt.executor.clients.openai.models.OpenAIResponsesAPIResponse
+import ai.koog.prompt.executor.clients.openai.models.OpenAIStreamEvent
+import ai.koog.prompt.executor.clients.openai.models.OpenAITextConfig
 import ai.koog.prompt.llm.LLMProvider
 import ai.koog.prompt.message.Message
+import ai.koog.prompt.message.RequestMetaInfo
+import ai.koog.prompt.streaming.StreamFrame
 import ai.koog.test.utils.CapturingKoogHttpClient
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.emptyFlow
+import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.toList
 import kotlinx.coroutines.test.runTest
+import kotlin.reflect.KClass
 import kotlin.test.Test
 import kotlin.test.assertEquals
 import kotlin.test.assertIs
@@ -58,4 +72,140 @@ class OpenAIPrimaryConstructorTest {
         val message = assertIs<Message.Assistant>(responses.single())
         assertEquals("Hello from KoogHttpClient", message.content)
     }
+
+    @Test
+    fun `primary constructor should stream reasoning frames through provided koog http client`() = runTest {
+        val responsesPath = "v1/responses"
+        val reasoningId = "reasoning_123"
+        val reasoningDelta = "Thinking"
+        val reasoningContent = "Thinking complete"
+        val reasoningSummary = "Short summary"
+        val encryptedReasoning = "enc_123"
+        val responseId = "resp_123"
+        val inputTokens = 3
+        val outputTokens = 4
+        val reasoningTokens = 2
+        val totalTokens = 7
+
+        val transport = object : KoogHttpClient {
+            override val clientName: String = "StreamingOpenAIClient"
+
+            override suspend fun <R : Any> get(
+                path: String,
+                responseType: KClass<R>,
+                parameters: Map<String, String>,
+            ): R = error("GET is not expected in this test")
+
+            override suspend fun <T : Any, R : Any> post(
+                path: String,
+                request: T,
+                requestBodyType: KClass<T>,
+                responseType: KClass<R>,
+                parameters: Map<String, String>,
+            ): R = error("POST is not expected in this test")
+
+            override fun <T : Any, R : Any, O : Any> sse(
+                path: String,
+                request: T,
+                requestBodyType: KClass<T>,
+                dataFilter: (String?) -> Boolean,
+                decodeStreamingResponse: (String) -> R,
+                processStreamingChunk: (R) -> O?,
+                parameters: Map<String, String>,
+            ): Flow<O> {
+                assertEquals(responsesPath, path)
+
+                val events = listOfNotNull(
+                    processStreamingChunk(
+                        OpenAIStreamEvent.ResponseReasoningTextDelta(
+                            itemId = reasoningId,
+                            outputIndex = 0,
+                            contentIndex = 0,
+                            delta = reasoningDelta,
+                            sequenceNumber = 1
+                        ) as R
+                    ),
+                    processStreamingChunk(OpenAIStreamEvent.ResponseKeepalive(sequenceNumber = 2) as R),
+                    processStreamingChunk(
+                        OpenAIStreamEvent.ResponseOutputItemDone(
+                            item = Item.Reasoning(
+                                id = reasoningId,
+                                summary = listOf(Item.Reasoning.Summary(reasoningSummary)),
+                                content = listOf(Item.Reasoning.Content(reasoningContent)),
+                                encryptedContent = encryptedReasoning,
+                                status = OpenAIInputStatus.COMPLETED
+                            ),
+                            outputIndex = 0,
+                            sequenceNumber = 3
+                        ) as R
+                    ),
+                    processStreamingChunk(
+                        OpenAIStreamEvent.ResponseCompleted(
+                            response = OpenAIResponsesAPIResponse(
+                                created = 1716920005,
+                                id = responseId,
+                                model = "gpt-5",
+                                output = emptyList(),
+                                parallelToolCalls = false,
+                                status = OpenAIInputStatus.COMPLETED,
+                                text = OpenAITextConfig(),
+                                usage = OpenAIResponsesAPIResponse.Usage(
+                                    inputTokens = inputTokens,
+                                    inputTokensDetails = OpenAIResponsesAPIResponse.Usage.InputTokensDetails(cachedTokens = 0),
+                                    outputTokens = outputTokens,
+                                    outputTokensDetails = OpenAIResponsesAPIResponse.Usage.OutputTokensDetails(reasoningTokens = reasoningTokens),
+                                    totalTokens = totalTokens
+                                )
+                            ),
+                            sequenceNumber = 4
+                        ) as R
+                    )
+                )
+
+                return if (events.isNotEmpty()) {
+                    flow {
+                        events.forEach { emit(it) }
+                    }
+                } else {
+                    emptyFlow()
+                }
+            }
+
+            override fun close(): Unit = Unit
+        }
+        val client = OpenAILLMClient(
+            settings = OpenAIClientSettings(baseUrl = "https://unused.test"),
+            httpClient = transport
+        )
+
+        val frames = client.executeStreaming(
+            prompt = Prompt(
+                messages = listOf(Message.User("Hello?", RequestMetaInfo.Empty)),
+                id = "test",
+                params = OpenAIResponsesParams()
+            ),
+            model = OpenAIModels.Chat.GPT4o
+        ).toList()
+
+        assertEquals(3, frames.size)
+        assertEquals(
+            StreamFrame.ReasoningDelta(id = reasoningId, text = reasoningDelta, index = 0),
+            frames[0]
+        )
+        assertEquals(
+            StreamFrame.ReasoningComplete(
+                id = reasoningId,
+                text = listOf(reasoningContent),
+                summary = listOf(reasoningSummary),
+                encrypted = encryptedReasoning,
+                index = 0
+            ),
+            frames[1]
+        )
+        val end = assertIs<StreamFrame.End>(frames[2])
+        assertEquals(null, end.finishReason)
+        assertEquals(totalTokens, end.metaInfo.totalTokensCount)
+        assertEquals(inputTokens, end.metaInfo.inputTokensCount)
+        assertEquals(outputTokens, end.metaInfo.outputTokensCount)
+    }
 }
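The same constructor injection exercised by these tests is available to callers: `OpenAILLMClient` accepts any `KoogHttpClient`, and streaming flows through that transport's `sse()`. A minimal usage sketch, assuming `myTransport` is a caller-supplied implementation and the base URL is a placeholder:

```kotlin
// Build the client over a custom transport (the tests use a hand-rolled fake).
val client = OpenAILLMClient(
    settings = OpenAIClientSettings(baseUrl = "https://api.openai.com"),
    httpClient = myTransport,
)

// Reasoning, keepalive, and completion events from the transport surface as
// StreamFrame values, as asserted in the test above.
val frames = client.executeStreaming(
    prompt = Prompt(
        messages = listOf(Message.User("Hello?", RequestMetaInfo.Empty)),
        id = "demo",
        params = OpenAIResponsesParams(),
    ),
    model = OpenAIModels.Chat.GPT4o,
).toList()
```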
@@ -725,6 +725,15 @@ class OpenAIStreamEventsTest {
         }
     }
 
+    @Test
+    fun `test keepalive event`() = runWithBothJsonConfigurations("keepalive event") { json ->
+        json.decodeFromString<OpenAIStreamEvent.ResponseKeepalive>(
+            json.encodeToString(OpenAIStreamEvent.ResponseKeepalive(48))
+        ).shouldNotBeNull {
+            sequenceNumber shouldBe 48
+        }
+    }
+
     @Test
     fun `test stream error event`() = runWithBothJsonConfigurations("stream error") { json ->
         json.decodeFromString<OpenAIStreamEvent.Error>(