Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Local streaming feature #16

Closed
wants to merge 7 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update SDK to 0.0.64.rc1
cmodi-meta committed Jan 14, 2025
commit b6a1e501c7a40103cb8f71af96f385ed12eb7b0f
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
@@ -4,5 +4,5 @@ plugins {

allprojects {
group = "com.llama.llamastack"
version = "0.0.63"
version = "0.0.64.rc1"
}
Original file line number Diff line number Diff line change
@@ -27,9 +27,11 @@ constructor(
private var modelName: String = ""

private var sequenceLengthKey: String = "seq_len"
private var stopToken: String = ""

override fun onResult(p0: String?) {
if (PromptFormatLocal.getStopTokens(modelName).any { it == p0 }) {
stopToken = p0!!
onResultComplete = true
return
}
@@ -55,7 +57,7 @@ constructor(
params: InferenceChatCompletionParams,
requestOptions: RequestOptions
): InferenceChatCompletionResponse {
resultMessage = ""
clearElements()
val mModule = clientOptions.llamaModule
modelName = params.modelId()
val formattedPrompt =
@@ -79,7 +81,7 @@ constructor(
onResultComplete = false
onStatsComplete = false

return buildInferenceChatCompletionResponse(resultMessage, statsMetric)
return buildInferenceChatCompletionResponse(resultMessage, statsMetric, stopToken)
}

override fun chatCompletionStreaming(
@@ -109,4 +111,9 @@ constructor(
): EmbeddingsResponse {
TODO("Not yet implemented")
}

fun clearElements() {
    // Reset per-request accumulation state so the next inference call
    // starts from a clean slate. The two resets are independent of each other.
    stopToken = ""
    resultMessage = ""
}
}
Original file line number Diff line number Diff line change
@@ -20,6 +20,18 @@ constructor(
TODO("Not yet implemented")
}

/**
 * Tool-group service accessor. Not supported by this implementation:
 * calling it always throws [NotImplementedError] via `TODO`.
 */
override fun toolgroups(): ToolgroupService {
TODO("Not yet implemented")
}

/**
 * Tool service accessor. Not supported by this implementation:
 * calling it always throws [NotImplementedError] via `TODO`.
 */
override fun tools(): ToolService {
TODO("Not yet implemented")
}

/**
 * Tool-runtime service accessor. Not supported by this implementation:
 * calling it always throws [NotImplementedError] via `TODO`.
 */
override fun toolRuntime(): ToolRuntimeService {
TODO("Not yet implemented")
}

/**
 * Telemetry service accessor. Not supported by this implementation:
 * calling it always throws [NotImplementedError] via `TODO`.
 */
override fun telemetry(): TelemetryService {
TODO("Not yet implemented")
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.llama.llamastack.client.local.util

import com.llama.llamastack.models.CompletionMessage
import com.llama.llamastack.models.InferenceChatCompletionParams
import com.llama.llamastack.models.SystemMessage
import com.llama.llamastack.models.UserMessage
@@ -81,7 +80,9 @@ object PromptFormatLocal {
}
message.isCompletionMessage() -> {
// assistant message
val completionMessage: CompletionMessage? = message.completionMessage()
val completionMessage:
InferenceChatCompletionParams.Message.CompletionMessage? =
message.completionMessage()
val content: String? = completionMessage?.content()?.string()
if (content != null) {
format =
Original file line number Diff line number Diff line change
@@ -1,28 +1,40 @@
package com.llama.llamastack.client.local.util

import com.llama.llamastack.core.JsonValue
import com.llama.llamastack.models.CompletionMessage
import com.llama.llamastack.models.InferenceChatCompletionResponse
import com.llama.llamastack.models.InterleavedContent
import com.llama.llamastack.models.ToolCall
import java.util.UUID

fun buildInferenceChatCompletionResponse(
response: String,
stats: Float
stats: Float,
stopToken: String
): InferenceChatCompletionResponse {
// check for prefix [ and suffix ] if so then tool call.
// parse for "toolName", "additionalProperties"

var completionMessage =
if (response.startsWith("[") && response.endsWith("]")) {
// custom tool call
CompletionMessage.builder()
InferenceChatCompletionResponse.ChatCompletionResponse.CompletionMessage.builder()
.toolCalls(createCustomToolCalls(response))
.content(InterleavedContent.ofString(""))
.role(
InferenceChatCompletionResponse.ChatCompletionResponse.CompletionMessage.Role
.ASSISTANT
)
.stopReason(mapStopTokenToReason(stopToken))
.build()
} else {
CompletionMessage.builder().content(InterleavedContent.ofString(response)).build()
InferenceChatCompletionResponse.ChatCompletionResponse.CompletionMessage.builder()
.toolCalls(listOf())
.content(InterleavedContent.ofString(response))
.role(
InferenceChatCompletionResponse.ChatCompletionResponse.CompletionMessage.Role
.ASSISTANT
)
.stopReason(mapStopTokenToReason(stopToken))
.build()
}

var inferenceChatCompletionResponse =
@@ -67,3 +79,18 @@ fun createCustomToolCalls(response: String): List<ToolCall> {

return toolCalls.toList()
}

/**
 * Translates a raw stop-token string emitted by the model into the SDK's
 * [InferenceChatCompletionResponse.ChatCompletionResponse.CompletionMessage.StopReason].
 *
 * `<|eot_id|>` maps to END_OF_TURN and `<|eom_id|>` maps to END_OF_MESSAGE;
 * any other value (including an empty string) falls back to OUT_OF_TOKENS.
 */
fun mapStopTokenToReason(
    stopToken: String
): InferenceChatCompletionResponse.ChatCompletionResponse.CompletionMessage.StopReason {
    return when (stopToken) {
        "<|eot_id|>" ->
            InferenceChatCompletionResponse.ChatCompletionResponse.CompletionMessage.StopReason
                .END_OF_TURN
        "<|eom_id|>" ->
            InferenceChatCompletionResponse.ChatCompletionResponse.CompletionMessage.StopReason
                .END_OF_MESSAGE
        else ->
            InferenceChatCompletionResponse.ChatCompletionResponse.CompletionMessage.StopReason
                .OUT_OF_TOKENS
    }
}
Original file line number Diff line number Diff line change
@@ -22,11 +22,20 @@ import com.llama.llamastack.services.blocking.ScoringService
import com.llama.llamastack.services.blocking.ShieldService
import com.llama.llamastack.services.blocking.SyntheticDataGenerationService
import com.llama.llamastack.services.blocking.TelemetryService
import com.llama.llamastack.services.blocking.ToolRuntimeService
import com.llama.llamastack.services.blocking.ToolService
import com.llama.llamastack.services.blocking.ToolgroupService

interface LlamaStackClientClient {

fun async(): LlamaStackClientClientAsync

fun toolgroups(): ToolgroupService

fun tools(): ToolService

fun toolRuntime(): ToolRuntimeService

fun agents(): AgentService

fun batchInference(): BatchInferenceService
Original file line number Diff line number Diff line change
@@ -22,11 +22,20 @@ import com.llama.llamastack.services.async.ScoringServiceAsync
import com.llama.llamastack.services.async.ShieldServiceAsync
import com.llama.llamastack.services.async.SyntheticDataGenerationServiceAsync
import com.llama.llamastack.services.async.TelemetryServiceAsync
import com.llama.llamastack.services.async.ToolRuntimeServiceAsync
import com.llama.llamastack.services.async.ToolServiceAsync
import com.llama.llamastack.services.async.ToolgroupServiceAsync

interface LlamaStackClientClientAsync {

fun sync(): LlamaStackClientClient

fun toolgroups(): ToolgroupServiceAsync

fun tools(): ToolServiceAsync

fun toolRuntime(): ToolRuntimeServiceAsync

fun agents(): AgentServiceAsync

fun batchInference(): BatchInferenceServiceAsync
Original file line number Diff line number Diff line change
@@ -44,6 +44,12 @@ import com.llama.llamastack.services.async.SyntheticDataGenerationServiceAsync
import com.llama.llamastack.services.async.SyntheticDataGenerationServiceAsyncImpl
import com.llama.llamastack.services.async.TelemetryServiceAsync
import com.llama.llamastack.services.async.TelemetryServiceAsyncImpl
import com.llama.llamastack.services.async.ToolRuntimeServiceAsync
import com.llama.llamastack.services.async.ToolRuntimeServiceAsyncImpl
import com.llama.llamastack.services.async.ToolServiceAsync
import com.llama.llamastack.services.async.ToolServiceAsyncImpl
import com.llama.llamastack.services.async.ToolgroupServiceAsync
import com.llama.llamastack.services.async.ToolgroupServiceAsyncImpl

class LlamaStackClientClientAsyncImpl
constructor(
@@ -61,6 +67,16 @@ constructor(
// Pass the original clientOptions so that this client sets its own User-Agent.
private val sync: LlamaStackClientClient by lazy { LlamaStackClientClientImpl(clientOptions) }

private val toolgroups: ToolgroupServiceAsync by lazy {
ToolgroupServiceAsyncImpl(clientOptionsWithUserAgent)
}

private val tools: ToolServiceAsync by lazy { ToolServiceAsyncImpl(clientOptionsWithUserAgent) }

private val toolRuntime: ToolRuntimeServiceAsync by lazy {
ToolRuntimeServiceAsyncImpl(clientOptionsWithUserAgent)
}

private val agents: AgentServiceAsync by lazy {
AgentServiceAsyncImpl(clientOptionsWithUserAgent)
}
@@ -141,6 +157,12 @@ constructor(

override fun sync(): LlamaStackClientClient = sync

override fun toolgroups(): ToolgroupServiceAsync = toolgroups

override fun tools(): ToolServiceAsync = tools

override fun toolRuntime(): ToolRuntimeServiceAsync = toolRuntime

override fun agents(): AgentServiceAsync = agents

override fun batchInference(): BatchInferenceServiceAsync = batchInference
Original file line number Diff line number Diff line change
@@ -44,6 +44,12 @@ import com.llama.llamastack.services.blocking.SyntheticDataGenerationService
import com.llama.llamastack.services.blocking.SyntheticDataGenerationServiceImpl
import com.llama.llamastack.services.blocking.TelemetryService
import com.llama.llamastack.services.blocking.TelemetryServiceImpl
import com.llama.llamastack.services.blocking.ToolRuntimeService
import com.llama.llamastack.services.blocking.ToolRuntimeServiceImpl
import com.llama.llamastack.services.blocking.ToolService
import com.llama.llamastack.services.blocking.ToolServiceImpl
import com.llama.llamastack.services.blocking.ToolgroupService
import com.llama.llamastack.services.blocking.ToolgroupServiceImpl

class LlamaStackClientClientImpl
constructor(
@@ -63,6 +69,16 @@ constructor(
LlamaStackClientClientAsyncImpl(clientOptions)
}

private val toolgroups: ToolgroupService by lazy {
ToolgroupServiceImpl(clientOptionsWithUserAgent)
}

private val tools: ToolService by lazy { ToolServiceImpl(clientOptionsWithUserAgent) }

private val toolRuntime: ToolRuntimeService by lazy {
ToolRuntimeServiceImpl(clientOptionsWithUserAgent)
}

private val agents: AgentService by lazy { AgentServiceImpl(clientOptionsWithUserAgent) }

private val batchInference: BatchInferenceService by lazy {
@@ -125,6 +141,12 @@ constructor(

override fun async(): LlamaStackClientClientAsync = async

override fun toolgroups(): ToolgroupService = toolgroups

override fun tools(): ToolService = tools

override fun toolRuntime(): ToolRuntimeService = toolRuntime

override fun agents(): AgentService = agents

override fun batchInference(): BatchInferenceService = batchInference
Original file line number Diff line number Diff line change
@@ -64,6 +64,12 @@ private constructor(

fun baseUrl(baseUrl: String) = apply { this.baseUrl = baseUrl }

fun responseValidation(responseValidation: Boolean) = apply {
this.responseValidation = responseValidation
}

fun maxRetries(maxRetries: Int) = apply { this.maxRetries = maxRetries }

fun headers(headers: Headers) = apply {
this.headers.clear()
putAllHeaders(headers)
@@ -144,12 +150,6 @@ private constructor(

fun removeAllQueryParams(keys: Set<String>) = apply { queryParams.removeAll(keys) }

fun responseValidation(responseValidation: Boolean) = apply {
this.responseValidation = responseValidation
}

fun maxRetries(maxRetries: Int) = apply { this.maxRetries = maxRetries }

fun fromEnv() = apply {}

fun build(): ClientOptions {
Loading