diff --git a/CLAUDE.md b/CLAUDE.md index f2db9fe..58450b6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ## Project Overview -**Llama Bro** is an Android SDK for on-device LLM inference, wrapping [llama.cpp](https://github.com/ggerganov/llama.cpp) via JNI. It consists of two modules: +**Llama Bro** is an Android SDK for on-device LLM inference, wrapping [llama.cpp](https://github.com/ggml-org/llama.cpp) via JNI. It consists of two modules: - **`:sdk`** — reusable Android library (published to JitPack) - **`:app`** — demo application showcasing the SDK @@ -25,20 +25,17 @@ git submodule update --init --recursive # Install debug app on connected device ./gradlew :app:installDebug -# Run unit tests (SDK only; no instrumentation tests exist) -./gradlew :sdk:test +# Run unit tests (SDK only) +./gradlew :sdk:testDebugUnitTest # Run a single test class -./gradlew :sdk:test --tests "com.suhel.llamabro.sdk.util.PromptFormatterTest" - -# Publish SDK to local Maven repository (used by JitPack) -./gradlew :sdk:publishToMavenLocal +./gradlew :sdk:testDebugUnitTest --tests "com.suhel.llamabro.sdk.chat.internal.LlamaChatSessionImplTest" # Clean ./gradlew clean ``` -**NDK requirement:** NDK 29.0.14206865 and CMake 3.22.1 must be installed via the Android SDK Manager. The project only builds for `arm64-v8a` — x86_64 emulators are not supported. +**NDK requirement:** NDK 29.0.14206865 and CMake 3.22.1 must be installed via the Android SDK Manager. The project only builds for `arm64-v8a`. ## Architecture @@ -49,40 +46,39 @@ UI (Jetpack Compose) ↓ ViewModels (MVVM, Hilt-injected) ↓ -Repositories (ChatRepository, ModelRepository) - ↓ -SDK Public API ───────────────────────────────────────────── +SDK Public API (internal implementations) LlamaEngine → LlamaSession → LlamaChatSession ↓ -JNI Bridge (llama_engine_jni.cpp, llama_session_jni.cpp) +Internal Pipeline (Declarative Flows) + session.generateFlow() -> Lexer -> Semantic Chunks -> Snapshot ↓ -Native C++ (session.cpp, engine.cpp → llama.cpp) +JNI Bridge (Kotlin ↔ Native Structs) + ↓ +Native C++ (llama.cpp) ``` ### SDK Module Three tiers of API, each building on the previous: -1. **`LlamaEngine`** — loads a GGUF model file from disk; creates sessions. Use `LlamaEngine.createFlow(modelConfig)` for reactive loading that emits `ResourceState`. +1. **`LlamaEngine`** — manages model weights. Use `LlamaEngine.createFlow(modelDefinition)` for reactive loading that emits `ResourceState`. -2. **`LlamaSession`** — low-level token control. Call `setSystemPrompt()`, then loop `prompt()` + `generate()` to produce tokens. Wrap using `createChatSession()` to get the high-level API. +2. **`LlamaSession`** — mutex-serialized token control. Call `setPrefixedPrompt()`, then use `generateFlow()` to stream native tokens via `channelFlow`. -3. **`LlamaChatSession`** — high-level conversational API. `completion(message)` returns `Flow`. Handles prompt template formatting, thinking-block extraction (`...`), and `OverflowStrategy`. +3. **`LlamaChatSession`** — high-level conversational API. `completion(ChatEvent.UserEvent)` returns `Flow`. Internally uses a DFA-based `AllocationOptimizedScanner` to extract text, thinking blocks, and tool calls. -**`ResourceState`** is the lifecycle ADT used throughout. It has subtypes `Loading(progress)`, `Success(value)`, `Failure(error)` and rich Flow extension operators (`flatMapResource`, `filterSuccess`, etc.) for composing resource loads. +**`ResourceState`** — lifecycle ADT (`Loading`, `Success`, `Failure`) with rich Flow extension operators (`flatMapResource`, `filterSuccess`). -**`PromptFormat`** / **`PromptFormats`** — chat template definitions. Built-in formats: `Llama3`, `Gemma3`, `ChatML` (Qwen/Yi), `Mistral`. Each model in `ModelZoo` references one of these. +**`ChatEvent`** — sealed hierarchy for conversation history. `AssistantEvent` is parts-based (Text, Thinking, ToolCall). -**`LlamaError`** — sealed error hierarchy (`ModelNotFound`, `ModelLoadFailed`, `ContextOverflow`, `DecodeFailed`, `Cancelled`, `NativeException`, etc.). +**`LlamaError`** — sealed error hierarchy mapped from native codes/exceptions. ### App Module Standard MVVM with Hilt DI: - -- **`ModelRepository`** — singleton managing model download/load/eject lifecycle. Download state is a FSM: `NotDownloaded → Downloading → Downloaded`. The currently loaded engine is exposed as `currentInferenceContextFlow: StateFlow`. -- **`ChatRepository`** — Room-backed CRUD for conversations and messages. -- **`ModelZoo`** — hardcoded list of 6 pre-curated GGUF models (SmolLM2 135M–1.7B, Qwen2.5 0.5B, Llama-3.2 1B, DeepSeek-R1 1.5B) with download URLs and default configs. -- Navigation uses type-safe `Route` sealed class with Jetpack Navigation Compose. +- **`ModelRepository`** — manages engine lifecycle. +- **`ChatRepository`** — Room-backed storage. +- **`ModelZoo`** — curated list of GGUF models with optimal `ModelDefinition` presets. ### JNI / Native @@ -94,12 +90,14 @@ Standard MVVM with Hilt DI: | Class | Purpose | |---|---| -| `ModelConfig` | Model path, `PromptFormat`, MMAP/MLOCK flags, thread count | -| `SessionConfig` | Context size, `OverflowStrategy`, `InferenceConfig`, `DecodeConfig` | -| `InferenceConfig` | Temperature, top-p/k, min-p, repeat penalty | -| `DecodeConfig` | Batch sizes for performance tuning | -| `PromptFormat` | Per-role prefix/suffix tokens, BOS/EOS, `` tag markers | +| `ModelDefinition` | `ModelLoadConfig` (path, threads, mmap) + `PromptFormat` + `FeatureMarker` | +| `SessionConfig` | `contextSize`, `OverflowStrategy`, `InferenceConfig`, `DecodeConfig` | +| `PromptFormat` | Role markers, `stopStrings`, prefix injection logic | ## Testing -Unit tests live in `sdk/src/test/` — currently only `PromptFormatterTest` covering chat template formatting. There are no instrumentation tests. New SDK behavior should be covered in this test source set. +Unit tests in `sdk/src/test/`: +- `AllocationOptimizedScannerTest` — DFA lexing logic. +- `PromptFormatterTest` — chat template serialization. +- `LlamaChatSessionImplTest` — full pipeline integration (replayed via FakeSession). +- `ResourceStateTest` — state transition logic. diff --git a/README.md b/README.md index 13737c6..6dde889 100644 --- a/README.md +++ b/README.md @@ -1,54 +1,144 @@ -# Llama Bro SDK Android +# [Llama Bro SDK](https://github.com/whyisitworking/llama-bro) -> **Run a full AI model in your pocket. On your terms. No servers. No subscriptions. No data leaving your phone.** +

+ Download APK + Version + API 24+ + License +

-Banner +

+ JitPack + ABI + Stars +

-[![Build](https://img.shields.io/github/actions/workflow/status/whyisitworking/llama-bro/build.yml?style=flat-square&logo=github&label=Build)](https://github.com/whyisitworking/llama-bro/actions/workflows/build.yml) -[![Release](https://img.shields.io/github/v/release/whyisitworking/llama-bro?style=flat-square&logo=android&label=Demo%20APK&color=success)](https://github.com/whyisitworking/llama-bro/releases/latest/download/LlamaBro-Demo.apk) -[![JitPack](https://img.shields.io/jitpack/v/github/whyisitworking/llama-bro?style=flat-square&logo=git&color=brightgreen)](https://jitpack.io/#whyisitworking/llama-bro) -[![API](https://img.shields.io/badge/API-24%2B-brightgreen.svg?style=flat-square&logo=android)](https://android-arsenal.com/api?level=24) -[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg?style=flat-square&logo=apache)](LICENSE) -[![ABI](https://img.shields.io/badge/ABI-arm64--v8a-orange.svg?style=flat-square&logo=arm)](https://developer.android.com/ndk/guides/abis) +**Run a full AI model in your pocket. On your terms. No servers. No subscriptions. No data leaving your phone.** -Llama Bro is a thin, performant Android SDK that runs quantized LLM models directly on-device via [llama.cpp](https://github.com/ggml-org/llama.cpp). Built with Kotlin coroutines and structured concurrency for modern Android development. Whether you're a privacy-focused builder, an offline-first enthusiast, or just curious about what's possible on a phone, Llama Bro makes AI inference as simple as a single Gradle dependency. +

+ Llama Bro Banner +

-### 🚀 Try the Demo App -Want to see the inference speed and reasoning capabilities in action before writing any code? +--- + +## The Problem with Cloud AI + +Every time you send a message to a cloud LLM, that message travels to a datacenter. It's logged, processed, and potentially used to train the next model. Your health questions, your legal queries, your private relationship advice — all of it leaves your device. + +**Llama Bro is the answer to that.** + +We wrap [llama.cpp](https://github.com/ggml-org/llama.cpp) in a clean, idiomatic Kotlin SDK so you can run state-of-the-art models — Llama 3, Gemma, DeepSeek-R1, Qwen 2.5 — directly on the device. No API keys. No usage limits. No data residency concerns. Your model, your hardware, your rules. + +--- -[![Download APK](https://img.shields.io/badge/Download_Demo_APK-Latest_Release-2ea44f?style=for-the-badge&logo=android)](https://github.com/whyisitworking/llama-bro/releases/latest/download/LlamaBro-Demo.apk) +## See it in Action
- + +
+ Real-time token streaming on Snapdragon 8 Elite. No cloud. No lag.
--- -## Why Llama Bro? +## What's New — Declarative Inference Pipeline + +The headline feature of the most recent architectural refactor is the **Declarative Inference Pipeline** — a fully reactive, allocation-optimized token processing engine that maps raw native output directly to your UI without a single blocking call. + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ 1. USER PROMPT │ +│ chat.completion(ChatEvent.UserEvent("Hello", think = true)) │ +└─────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ 2. PROMPT FORMATTER │ +│ Wraps the message in model-specific chat markers: │ +│ <|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n │ +└─────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ 3. NATIVE GENERATOR │ +│ llama_decode() → channelFlow { send(token) } │ +│ Running on Dispatchers.IO, legally cross-context. │ +└─────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ 4. DFA LEXER (AllocationOptimizedScanner) │ +│ Scans the raw token stream character-by-character. │ +│ Detects: text | ... | ...│ +│ Uses StringBuilder, not String concat — 0 GC pressure. │ +└─────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ 5. SEMANTIC CHUNKING │ +│ Emits typed chunks: TextChunk | ThinkingChunk | ToolChunk │ +│ Assembled into AssistantEvent.Part objects. │ +└─────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ 6. COMPLETION SNAPSHOT │ +│ Each emission: { message, tokensPerSecond, isComplete } │ +│ Your UI collects this — full content, always cumulative. │ +└─────────────────────────────────────────────────────────────────┘ +``` + +### How Pipeline Composition Works + +```kotlin +LlamaEngine.createFlow(modelDefinition) // Load model → ResourceState + .flatMapResource { engine -> // When loaded, create session + engine.createSessionFlow(sessionConfig) + } + .flatMapResource { session -> // When session ready, create chat + session.createChatSessionFlow(systemPrompt) + } + .filterSuccess() // Extract the chat session + .flatMapLatest { chat -> // On each user turn + chat.completion(userEvent) + } + .collect { snapshot -> // UI-ready snapshot, on every token + updateTextView(snapshot.message.text) + if (snapshot.isComplete) saveToDb(snapshot) + } +``` -**🔒 True Privacy** -No API keys. No telemetry. No models calling home. Your data never leaves the device. +No threading code. No callbacks. No lifecycle leaks. Cancellation is free. -**💰 Zero Token Cost** -Run models as much as you want—no usage limits, no payment APIs, no surprise bills. +--- -**⚡ Fast Local Inference** -Tap the device's SIMD capabilities (NEON, dotprod, i8mm) for real-time responses. ~30 tokens/second on Snapdragon 8 Elite with Llama 3.2 1B. +## Features -**📱 Built for Android** -Kotlin-native coroutine API. No threading headaches. No callback hell. Just `suspend fun` and `Flow`. +- **Zero-Allocation Streaming** — DFA-based scanner (`AllocationOptimizedScanner`) uses StringBuilder internally and avoids per-token heap allocations, keeping the UI thread smooth +- **Thinking Block Extraction** — First-class support for `...` in reasoning models (DeepSeek-R1, QwQ, MiniMax). Thinking text and response text are separated automatically +- **Declarative Flow API** — `ResourceState` ADT with `flatMapResource`, `filterSuccess`, `onEachLoading`, and `fold` operators for composing resource loads declaratively +- **Prompt Format Library** — 6 built-in chat templates (Gemma, Llama 3, ChatML, DeepSeek-R1, Mistral, Nemotron) + `QWEN_2_5` alias + support for fully custom formats, including "turn-start" injection for forcing thinking +- **Overflow Management** — 3 strategies for handling full KV caches: `Halt`, `ClearHistory`, `RollingWindow` — configurable per session +- **Type-Safe Errors** — `LlamaError` sealed class maps every native failure to a named subtype. No raw exceptions from the JNI boundary +- **History Replay** — `feedHistory(List)` pre-populates the KV cache with a prior conversation, so follow-up generations are contextual -**🎯 Production-Ready** -Thread-safe sessions. Memory-safe lifecycle management. Structured error handling. Works with Hilt. Works with architecture patterns you already use. +### Built-In Prompt Formats -**🧠 Reasoning Models Included** -Built-in thinking-block parsing for DeepSeek-R1, QwQ, and other reasoning models. See the model's thought process. +| Template | Protocol | Best For | +|-----------------|-------------------------------------------------------|---------------------------------------| +| `GEMMA` | `` / `` | Google Gemma / Gemma 2 / Gemma 3n | +| `LLAMA_3` | `<\|start_header_id\|>` / `<\|eot_id\|>` | Llama 3 / 3.1 / 3.2 / 3.3 | +| `CHAT_ML` | `<\|im_start\|>` / `<\|im_end\|>` | SmolLM2, Qwen 2.5, Yi, Hermes | +| `QWEN_2_5` | alias for `CHAT_ML` | Qwen 2.5 (convenient alias) | +| `DEEPSEEK_R1` | `<|begin of sentence|>` / `<|end of sentence|>` | DeepSeek-R1 / R1-Distill family | +| `MISTRAL` | `[INST]` / `[/INST]` | Mistral 7B, Mixtral 8x7B | +| `NEMOTRON` | `` / `` | NVIDIA Nemotron-Mini | --- -## Quick Start (3 minutes) +## Installation -### 1. Add the dependency +### 1. Add JitPack to your repositories ```kotlin // settings.gradle.kts @@ -59,467 +149,246 @@ dependencyResolutionManagement { maven { url = uri("https://jitpack.io") } } } +``` -// Check the JitPack/Github badge above for the latest version number. build.gradle.kts (app) +### 2. Add the dependency + +```kotlin +// build.gradle.kts (app) dependencies { implementation("com.github.whyisitworking:llama-bro:") } ``` -### 2. Download a model +Check the JitPack badge above for the latest version. + +--- -Grab a GGUF-quantised model from [Hugging Face](https://huggingface.co/models?library=gguf). **Recommended for first-time users:** +## Prerequisites -- **Gemma 3n (2B, Q4_K_M, ~3 GB)** — Best balance of speed and quality - - [unsloth/gemma-3n-E2B-it-GGUF](https://huggingface.co/unsloth/gemma-3n-E2B-it-GGUF) -- **Llama 3.2 (1B, Q4_K_M, ~600 MB)** — Ultra-lightweight; good for testing - - [bartowski/Llama-3.2-1B-Instruct-GGUF](https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF) -- **DeepSeek-R1 (7B, Q4_K_M, ~5 GB)** — For reasoning/thinking blocks - - [bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF](https://huggingface.co/bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF) +### Download a GGUF Model -**Quantisation guide:** `Q4_K_M` is the gold standard on mobile (best quality-to-speed tradeoff). For maximum speed on low-RAM devices, try `Q3_K_M` or `Q2_K`. For quality-first, try `Q5_K_M` (slower, larger). +Grab a GGUF-quantised model from [Hugging Face](https://huggingface.co/models?library=gguf). -### 3. Load and chat +**Recommended starting points:** + +| Model | Size | Format | Recommended Source | +|---|---|---|---| +| **Llama 3.2 1B** | ~600 MB | `LLAMA_3` | [bartowski/Llama-3.2-1B-Instruct-GGUF](https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF) | +| **Gemma 3n 2B** | ~3 GB | `GEMMA` | [unsloth/gemma-3n-E2B-it-GGUF](https://huggingface.co/unsloth/gemma-3n-E2B-it-GGUF) | +| **Qwen 2.5 0.5B** | ~400 MB | `QWEN_2_5` | [bartowski/Qwen2.5-0.5B-Instruct-GGUF](https://huggingface.co/bartowski/Qwen2.5-0.5B-Instruct-GGUF) | +| **DeepSeek-R1 7B** | ~5 GB | `DEEPSEEK_R1` | [bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF](https://huggingface.co/bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF) | + +**Quantization guide:** `Q4_K_M` is the mobile sweet spot — best quality-to-speed tradeoff. Go `Q3_K_M` for RAM-constrained devices. Go `Q5_K_M` for maximum quality (larger, slower). + +--- + +## Quick Start ```kotlin -import com.suhel.llamabro.sdk.* -import com.suhel.llamabro.sdk.model.flatMapResource -import com.suhel.llamabro.sdk.model.filterSuccess -import com.suhel.llamabro.sdk.model.onEachLoading +import com.suhel.llamabro.sdk.LlamaEngine +import com.suhel.llamabro.sdk.config.* +import com.suhel.llamabro.sdk.models.* +import com.suhel.llamabro.sdk.format.PromptFormats +import com.suhel.llamabro.sdk.model.* -// Declarative flow composition (auto-cleanup on cancellation) lifecycleScope.launch { LlamaEngine.createFlow( - ModelConfig( - modelPath = "/path/to/model.gguf", - promptFormat = PromptFormats.ChatML, + ModelDefinition( + loadConfig = ModelLoadConfig(path = "/path/to/model.gguf"), + promptFormat = PromptFormats.CHAT_ML, ) ) - .onEachLoading { progress -> - updateProgressBar(progress ?: 0f) + .onEachLoading { progress -> + progressBar.progress = ((progress ?: 0f) * 100).toInt() } .flatMapResource { engine -> engine.createSessionFlow( SessionConfig( contextSize = 4096, - overflowStrategy = OverflowStrategy.RollingWindow(500), - inferenceConfig = InferenceConfig( - temperature = 0.7, - repeatPenalty = 1.15, - ) + overflowStrategy = OverflowStrategy.RollingWindow(dropTokens = 500), + inferenceConfig = InferenceConfig(temperature = 0.7f, minP = 0.1f) ) ) } .flatMapResource { session -> session.createChatSessionFlow("You are a helpful assistant.") } - .filterSuccess() // Extract chat session, drop Loading/Failure - .flatMapLatest { chatSession -> - chatSession.completion("Explain coroutines in one paragraph.") + .filterSuccess() + .flatMapLatest { chat -> + chat.completion(ChatEvent.UserEvent("Explain coroutines.", think = false)) } - .collect { completion -> - updateTextView(completion.contentText.orEmpty()) - - if (completion.isComplete) { - logPerformance("${completion.tokensPerSecond} tokens/sec") + .collect { snapshot -> + textView.text = snapshot.message.text + if (snapshot.isComplete) { + speedLabel.text = "${snapshot.tokensPerSecond} tok/s" } } } ``` -That's it. No callbacks. No manual resource management. Flow handles cleanup when cancelled. - ---- - -## Use Cases - -**Privacy-First Chat** -Build conversational features for health, finance, or sensitive domains without worrying about data residency. - -**Offline Assistants** -Code editor plugins, keyboard assistants, or writing tools that work on a flight. - -**Real-Time Reasoning** -Run models that can think step-by-step (DeepSeek-R1, QwQ) and extract their reasoning for debugging or transparency. - -**Reducing Latency** -No round-trip to a remote server. Get responses in milliseconds, not seconds. - --- ## API Overview -Llama Bro's API is tiered by abstraction level. Use what you need: +The SDK is layered. Each tier adds abstraction. Use what your use case demands. -### `LlamaEngine` — Model loading and session factory +### `LlamaEngine` — The Model Loader -Responsibility: Load the GGUF file, manage model weights, create sessions. +Loads the GGUF file and manages model weights. Creates sessions on demand. Keep **one engine per model** across the app. ```kotlin -// Recommended: Flow-based (auto-cleanup on cancellation) -LlamaEngine.createFlow(modelConfig) - .collect { resourceState -> /* handle loading/success/error */ } - -// Manual: Explicit lifecycle control -val engine = LlamaEngine.create(modelConfig) { progress -> - updateProgressBar(progress) - true // Return false to cancel load -} +// Recommended: Flow-based (auto-cleanup on coroutine cancellation) +LlamaEngine.createFlow(modelDefinition) + .onEachLoading { progress -> showProgress(progress) } + .flatMapResource { engine -> /* use engine */ } +// Manual: You manage the lifecycle +val engine = LlamaEngine.create(modelDefinition) { progress -> true /* return false to cancel */ } val session = engine.createSession(sessionConfig) -// ... use session ... -engine.close() // Release model memory +engine.close() // Releases native memory ``` -**When to use:** Always start here. Keep one engine per model. +### `LlamaSession` — The Token Engine -### `LlamaSession` — Token-level inference - -Responsibility: Manage KV cache, encode/decode tokens, sample. +Manages the KV cache, token encoding, and sampling. Mutex-serialized for thread safety. ```kotlin -// Full control: manually encode, then generate tokens one by one -suspend fun manualInference(session: LlamaSession) { - session.setSystemPrompt("You are helpful.") - session.prompt("What's 2+2?") - - val output = StringBuilder() - while (true) { - val token = session.generate() ?: break // null = EOS - output.append(token) - } - return output.toString() +// Use the Flow API for standard sampling +session.generateFlow().collect { result -> + print(result.token ?: "") + if (result.isComplete) return@collect } ``` -**When to use:** Building custom sampling loops, token-level debugging, or advanced inference control. Most use cases don't need this. +> **When to use directly:** Implementing custom sampling loops, tool injection, or token-level diagnostics. Most apps should use `LlamaChatSession` instead. -### `LlamaChatSession` — High-level conversation API +### `LlamaChatSession` — The Chat API -Responsibility: Format messages, handle stop tokens, extract thinking blocks, compute metrics. +Handles prompt formatting, stop-token detection, thinking-block extraction, and metrics. This is where 95% of integrations start and end. ```kotlin -// Simple: One-liner for chat completions -chat.completion("Explain coroutines.").collect { completion -> - println(completion.contentText) +chat.completion(ChatEvent.UserEvent("Hello!", think = true)).collect { snapshot -> + // snapshot.message.text → Visible response + // snapshot.message.thinkingText → Hidden reasoning + // snapshot.tokensPerSecond → Generation speed + // snapshot.isComplete → True when done } ``` -**When to use:** 95% of use cases. Handles formatting, thinking block extraction, metrics. - --- -**All operations are `suspend` functions or `Flow`—no callbacks, no blocking threads.** +## Configuration Reference ---- - -## Configuration +### `ModelDefinition` -### Model Config +The root configuration object. Bundles load settings with the prompt format. -| Option | Default | Purpose | -|----------------|---------------------------|--------------------------------------------------------------------| -| `modelPath` | required | Absolute path to `.gguf` model file | -| `promptFormat` | required | Chat template (`Llama3`, `Gemma3`, `ChatML`, `Mistral`, or custom) | -| `threads` | `availableProcessors / 2` | Inference thread count; tune for performance cores | -| `useMmap` | `true` | Memory-map model file (faster loading, lower peak RAM) | -| `useMlock` | `false` | Lock model in RAM (prevents swapping; use on capable devices only) | - -**Example:** ```kotlin -ModelConfig( - modelPath = "/data/models/gemma-3n.gguf", - promptFormat = PromptFormats.Gemma3, - threads = 8, // Match your device's performance core count - useMmap = true, - useMlock = false +ModelDefinition( + loadConfig = ModelLoadConfig( + path = "/data/user/0/com.myapp/files/model.gguf", + threads = 8, // Match your device's performance core count + useMMap = true, // Memory-map the file (recommended) + useMLock = false // Lock in RAM (prevent OS swap — high-memory devices only) + ), + promptFormat = PromptFormats.LLAMA_3, + features = listOf(ThinkingMarker) // Enable thinking injection for reasoning models ) ``` -### Session Config - -| Option | Default | Purpose | -|--------------------|----------------------|-------------------------------------------------| -| `contextSize` | `4096` | KV cache size in tokens (max prompt + response) | -| `overflowStrategy` | `RollingWindow(500)` | Behavior when cache is exhausted | -| `inferenceConfig` | (see below) | Sampling parameters | -| `decodeConfig` | (see below) | Batch decoding tuning | -| `seed` | `-1` (random) | Set to a fixed value for reproducibility | +### `SessionConfig` -### Inference Config (Sampling) +| Option | Default | Notes | +|--------------------|----------------------|--------------------------------------------------------------------| +| `contextSize` | `2048` | Token budget for the entire conversation (prompt + response) | +| `overflowStrategy` | `RollingWindow(500)` | What happens when the KV cache fills up | +| `inferenceConfig` | See below | Sampling parameters | +| `decodeConfig` | See below | I/O batch sizes for performance tuning | +| `seed` | `-1` (random) | Set an integer for reproducible outputs | -Control token generation quality and diversity: +### `InferenceConfig` — Sampling -| Option | Default | Range | Purpose | -|-------------------|---------|-----------|-------------------------------------------------------------------------| -| `temperature` | `0.8` | `0.0–2.0` | Sampling creativity; `0.0` = greedy (always pick best), `1.0` = neutral | -| `repeatPenalty` | `1.1` | `1.0–2.0` | Penalize recent tokens to avoid loops | -| `presencePenalty` | `0.0` | `0.0–2.0` | Penalize all previously seen tokens | -| `minP` | `0.1` | `0.0–1.0` | Min-probability threshold; `null` to disable | -| `topP` | `null` | `0.0–1.0` | Nucleus sampling; `null` = disabled | -| `topK` | `null` | `1–∞` | Top-K sampling; `null` = disabled | +| Option | Default | Range | Effect | +|-------------------|---------|-----------|--------------------------------------------------------------------| +| `temperature` | `0.8f` | `0.0–2.0` | Randomness. `0.0` = deterministic greedy, `1.0` = neutral. | +| `repeatPenalty` | `1.0f` | `1.0–2.0` | Discourages the model from repeating recent tokens. | +| `presencePenalty` | `0.0f` | `0.0–2.0` | Penalizes all tokens that have appeared, not just recent ones. | +| `minP` | `0.1f` | `0.0–1.0` | Min-probability filter. Cuts "hallucination tail" tokens cleanly. | +| `topP` | `null` | `0.0–1.0` | Nucleus sampling. `null` = disabled. | +| `topK` | `null` | `1–∞` | Top-K sampling. `null` = disabled. | -**Example:** -```kotlin -InferenceConfig( - temperature = 0.7, // Slightly conservative - repeatPenalty = 1.15, // Discourage repetition - minP = 0.05, // Filter low-probability tokens - topP = 0.9 // Nucleus sampling for diversity -) -``` +### `DecodeConfig` — Performance -### Decode Config (Performance Tuning) +| Option | Default | Notes | +|------------------|---------|------------------------------------------------------| +| `batchSize` | `2048` | Max tokens processed per decode step. | +| `microBatchSize` | `512` | Internal chunking granularity. Lower = less RAM. | -| Option | Default | Purpose | -|-----------------------|---------|-----------------------------------------------| -| `batchSize` | `2048` | Max tokens processed per decode step | -| `microBatchSize` | `512` | Internal chunking for memory efficiency | -| `systemPromptReserve` | `100` | Tokens reserved for system prompt in rollover | - -Increase `batchSize` to `4096` for faster prefill on long system prompts; decrease if memory-constrained. +Increase `batchSize` to `4096` for faster long-prompt prefill. Reduce it on RAM-constrained devices. ### Overflow Strategies -When the KV cache reaches `contextSize`, select one: - -- **`Halt`** — Throw `LlamaError.ContextOverflow`. Use for strict determinism. -- **`ClearHistory`** — Discard all prior messages, reload system prompt, continue. -- **`RollingWindow(dropTokens)`** — Evict oldest `dropTokens` tokens, keep chatting (recommended). - -**Example:** -```kotlin -SessionConfig( - contextSize = 4096, - overflowStrategy = OverflowStrategy.RollingWindow(dropTokens = 500) -) -``` - ---- - -## Supported Models - -Llama Bro works with any GGUF model that runs in `llama.cpp`. Built-in templates cover the major families: - -| Template | Format | Recommended Models | Size Range | -|-----------|------------------------------------------|--------------------------------|----------------------------| -| `Gemma3` | `` / `` | Gemma 3, Gemma 3n | 2B–27B (Q4_K_M: 1.5–16 GB) | -| `Llama3` | `<\|start_header_id\|>` / `<\|eot_id\|>` | Llama 3 / 3.1 / 3.2 / 3.3 | 8B–70B (Q4_K_M: 5–40 GB) | -| `ChatML` | `<\|im_start\|>` / `<\|im_end\|>` | Qwen 2.5, Yi, InternLM, Hermes | 1B–72B (varies) | -| `Mistral` | `[INST]` / `[/INST]` | Mistral 7B, Mixtral 8x7B | 7B–46B (Q4_K_M: 4.5–30 GB) | - -**Finding models:** Browse [Hugging Face](https://huggingface.co/models?library=gguf) for GGUF quantisations. **Starting recommendation:** Gemma 3n (2B, ~3 GB) or Llama 3.2 (1B, ~600 MB) for testing. - -**Custom templates:** For models not listed above, define your own: - -```kotlin -val custom = PromptFormat( - systemPrefix = "<>\n", - systemSuffix = "\n<>\n\n", - userPrefix = "[INST] ", - userSuffix = " [/INST]", - assistantPrefix = "", - assistantSuffix = "" -) - -LlamaEngine.create( - ModelConfig(modelPath = "/path/to/model.gguf", promptFormat = custom) -) -``` - -**How to find the right template:** Check the model's Hugging Face card or README for the "chat template" field. It usually tells you the exact markers and order. - ---- - -## ResourceState Flow Extensions - -Llama Bro provides declarative flow operators to compose resource lifecycles cleanly: - -```kotlin -// Extract success value or null -val engine: LlamaEngine? = resourceState.getOrNull() - -// Transform success value, preserving Loading/Failure -val mapped: ResourceState = resourceState.map { it.toString() } - -// Transform success values in a flow -engineFlow.mapSuccess { engine -> MyWrapper(engine) } - -// Chain sequential resource flows (Engine → Session → Chat) -engineFlow - .flatMapResource { engine -> engine.createSessionFlow(config) } - .flatMapResource { session -> session.createChatSessionFlow("System") } - .filterSuccess() // Emit only loaded chat sessions - .collect { chat -> /* use chat */ } - -// React to success without transforming the value -engineFlow.onEachSuccess { engine -> - updateProgressBar(1.0f) // Engine loaded -} - -// Extract only successful values from a flow -chatFlow.filterSuccess() // Flow instead of Flow> - -// Exhaustive pattern matching -resourceState.fold( - onLoading = { progress -> showLoadingUI(progress ?: 0f) }, - onSuccess = { value -> showSuccessUI(value) }, - onFailure = { error -> showErrorUI(error) } -) - -// Extract value or supply a default on error/loading -val engine = resourceState.getOrElse { error -> fallbackEngine } -``` - -These operators enable **declarative, type-safe composition** without nested `when` blocks. +| Strategy | Behavior | Best For | +|--------------------------|------------------------------------------------------------------|---------------------------------------| +| `Halt` | Throws `LlamaError.ContextOverflow` | Strict determinism, batch processing | +| `ClearHistory` | Wipes context, reloads system prompt, continues | Short-session apps | +| `RollingWindow(n)` | Evicts oldest `n` tokens, keeps chatting | Long conversational flows (recommended) | --- ## Thinking Blocks & Reasoning Models -Reasoning models like DeepSeek-R1 and QwQ expose internal thoughts via `...` tags. Llama Bro extracts them automatically: +Reasoning models like **DeepSeek-R1** and **QwQ** expose their internal chain-of-thought inside `...` tags. Llama Bro automatically extracts these into a separate part of the `AssistantEvent`. ```kotlin -chat.completion("Hard problem").collect { completion -> - // View the model's reasoning - completion.thinkingText?.let { thinking -> - println("Model reasoning:\n$thinking") - } - - // View the final answer - println("Final answer:\n${completion.contentText}") - - // Check performance - if (completion.isComplete) { - println("Generated at ${completion.tokensPerSecond} tokens/sec") - } -} -``` - -Perfect for explainability, debugging, or understanding complex model behavior. The `Completion` data class provides: - -- **`thinkingText`** — Content inside `...` blocks (reasoning models only) -- **`contentText`** — Visible response text (everything outside thinking blocks) -- **`tokensPerSecond`** — Streaming performance metric -- **`isComplete`** — True when generation ends (EOS reached) - ---- - -## Architecture - -Llama Bro is a clean, layered stack: - -``` -┌────────────────────────────────┐ -│ LlamaChatSession │ High-level chat API -│ (formatting, stop detection) │ -├────────────────────────────────┤ -│ LlamaSession │ Token-level control -│ (mutex-serialized, abort-safe) │ -├────────────────────────────────┤ -│ LlamaEngine │ Model loader -│ (ResourceState lifecycle) │ -├────────────────────────────────┤ -│ JNI Bridge │ Kotlin ↔ C++ -│ (error codes, callbacks) │ -├────────────────────────────────┤ -│ C++ Engine (llama.cpp) │ GGML, SIMD backends -│ (NEON, dotprod, i8mm, SVE) │ -└────────────────────────────────┘ -``` - -All native pointers are wrapped in `AutoCloseable` interfaces. Cancellation is safe. Leaks are prevented. - ---- - -## Performance Tips - -**Benchmark:** ~20 tokens/second on OnePlus 13 (Gemma 3n 2B, Q4_K_M). - -**Tune for your device:** - -1. Use `Q4_K_M` quantisation—best quality-to-speed tradeoff on mobile. -2. Set `threads` to match your device's performance core count. -3. Keep `useMmap = true` (default) to avoid loading the full model into RAM. -4. Increase `decodeConfig.batchSize` to `4096` for faster prefill on long prompts. -5. Models > 4 GB may need `useMlock = false` to avoid out-of-memory on mid-range phones. - ---- - -## Completion Streaming - -Each emission from `chat.completion(message)` is a `Completion` snapshot: - -```kotlin -data class Completion( - val thinkingText: String?, // Internal reasoning (reasoning models only) - val contentText: String?, // Visible response text - val tokensPerSecond: Float?, // Performance metric (rolling average) - val isComplete: Boolean // True when EOS reached (generation done) +// Set think = true to inject the opening tag, +// forcing the model into reasoning mode. +val userEvent = ChatEvent.UserEvent( + content = "What is 17 × 23? Show your work.", + think = true ) -``` -The flow emits cumulative snapshots—each one contains all tokens generated so far, not just new tokens. Use `isComplete` to detect end-of-generation: +chat.completion(userEvent).collect { snapshot -> + // Display reasoning in a collapsible section + val reasoning = snapshot.message.thinkingText // "Let me calculate 17 × 23..." + val answer = snapshot.message.text // "The answer is 391." -```kotlin -chat.completion("Hello").collect { completion -> - if (!completion.isComplete) { - // Still generating - updateUI(completion.contentText) // Partial response - } else { - // Generation finished - saveToDatabase(completion) - logMetrics("Final speed: ${completion.tokensPerSecond} t/s") + if (snapshot.isComplete) { + println("${snapshot.tokensPerSecond} tokens/sec") } } ``` ---- - -## Conversation History - -Restore prior chats from a database: - -```kotlin -val history = listOf( - Message.User("What's Kotlin?"), - Message.Assistant("Kotlin is a JVM language..."), - Message.User("And coroutines?") -) - -chat.loadHistory(history) -chat.completion("Explain together").collect { /* ... */ } -``` - -The session's KV cache is pre-populated with the history, so the next response is contextual and faster. +The `think = true` parameter only works on models with `ThinkingMarker` in their `ModelDefinition.features`. On non-thinking models, it is silently ignored — making the API safe to use unconditionally. --- ## Error Handling -All errors cross the JNI boundary as typed exceptions: +All native failures cross the JNI boundary as typed `LlamaError` subtypes: ```kotlin -sealed class LlamaError : Exception { - class ModelNotFound(val path: String) - class ModelLoadFailed(val path: String, cause: Throwable?) - class BackendLoadFailed(val backendName: String) - class ContextInitFailed(cause: Throwable?) - class ContextOverflow - class DecodeFailed(val code: Int) - class NativeException(val nativeMessage: String, cause: Throwable?) +sealed class LlamaError : Exception() { + class ModelNotFound(val path: String) : LlamaError() + class ModelLoadFailed(val path: String, cause: Throwable?) : LlamaError() + class BackendLoadFailed(val backendName: String) : LlamaError() + class ContextInitFailed(cause: Throwable?) : LlamaError() + class ContextOverflow : LlamaError() + class DecodeFailed(val code: Int) : LlamaError() + class NativeException(val nativeMessage: String, cause: Throwable?) : LlamaError() } ``` -Handle them with idiomatic Kotlin: +Compose error recovery into the same flow chain: ```kotlin -LlamaEngine.createFlow(modelConfig) +LlamaEngine.createFlow(modelDefinition) .catch { e -> when (e) { - is LlamaError.ModelNotFound -> showFilePicker() - is LlamaError.ContextOverflow -> handleFullContext() - else -> throw e + is LlamaError.ModelNotFound -> showModelPickerUI() + is LlamaError.ContextOverflow -> onContextFull() + else -> logAndRethrow(e) } } .collect { /* ... */ } @@ -527,300 +396,126 @@ LlamaEngine.createFlow(modelConfig) --- -## Examples +## Conversation History -### Example 1: ViewModel with Declarative Flow Composition +Re-populate the KV cache with a prior conversation before the next turn: ```kotlin -@HiltViewModel -class ChatViewModel @Inject constructor( - modelRepository: ModelRepository // Injected model source -) : ViewModel() { - - private val modelPath = "/data/models/gemma-3n.gguf" // Inject through adb/Android file explorer for testing +val history: List = listOf( + ChatEvent.UserEvent("What's Kotlin?", think = false), + ChatEvent.AssistantEvent(listOf( + ChatEvent.AssistantEvent.Part.TextPart("Kotlin is a JVM language by JetBrains.") + )), + ChatEvent.UserEvent("And coroutines?", think = false), +) - // Single declarative flow chain for the entire lifecycle - private val chatSessionFlow = LlamaEngine.createFlow( - ModelConfig(modelPath, PromptFormats.Gemma3) - ) - .onEachSuccess { _loadingProgress.value = 1.0f } // Model ready - .flatMapResource { engine -> - engine.createChatSessionFlow( - systemPrompt = "You are a helpful assistant.", - sessionConfig = SessionConfig( - contextSize = 4096, - overflowStrategy = OverflowStrategy.RollingWindow(500) - ) - ) - } - .filterSuccess() // Only emit loaded chat sessions - .stateIn(viewModelScope, SharingStarted.Lazily, null) - - fun sendMessage(userMessage: String): Flow = - chatSessionFlow - .filterNotNull() - .flatMapLatest { chat -> - chat.completion(userMessage) - .map { it.contentText.orEmpty() } - } -} +chat.feedHistory(history) +chat.completion(ChatEvent.UserEvent("Give me an example.", think = false)) + .collect { snapshot -> /* ... */ } ``` -No manual `onCleared()` cleanup needed—the flow scoping handles lifecycle automatically. +The session processes history tokens once — subsequent context is pre-warmed and generations are faster. -### Example 2: Real-Time Streaming to UI +--- -```kotlin -@HiltViewModel -class ResponseViewModel @Inject constructor() : ViewModel() { - private val _uiState = MutableStateFlow(ResponseUiState.Idle) - val uiState = _uiState.asStateFlow() - - fun generateResponse(message: String) { - viewModelScope.launch { - chatSession.completion(message) - .onStart { _uiState.value = ResponseUiState.Loading } - .collect { completion -> - _uiState.value = ResponseUiState.Streaming( - text = completion.contentText.orEmpty(), - tokensPerSecond = completion.tokensPerSecond, - thinking = completion.thinkingText, - isComplete = completion.isComplete - ) - } - } - } -} +## ResourceState Flow Operators -// In Compose or XML Fragment -uiState.collect { state -> - when (state) { - is ResponseUiState.Streaming -> { - Text(state.text) // Updates in real-time - if (state.thinking != null) { - CollapsibleThinkingBlock(state.thinking) - } - if (state.isComplete) { - PerformanceLabel("${state.tokensPerSecond} tokens/sec") - } - } - ResponseUiState.Loading -> LoadingSpinner() - ResponseUiState.Idle -> {} - } +`ResourceState` is the lifecycle ADT powering the entire SDK: + +```kotlin +sealed class ResourceState { + data class Loading(val progress: Float?) : ResourceState() + data class Success(val value: T) : ResourceState() + data class Failure(val error: Throwable) : ResourceState() } ``` -### Example 3: Extracting Reasoning from DeepSeek-R1 +Compose resource flows declaratively using built-in operators: -```kotlin -// Use with reasoning models (DeepSeek-R1, QwQ) -val engine = LlamaEngine.create( - ModelConfig( - modelPath = "/data/models/deepseek-r1-7b-q4.gguf", - promptFormat = PromptFormats.ChatML - ) -) +| Operator | Use | +|---|---| +| `flatMapResource { }` | Chain a resource-loading step onto an existing one | +| `filterSuccess()` | Strip the wrapper, emit only successful values as `Flow` | +| `onEachLoading { }` | React to progress without leaving the chain | +| `onEachSuccess { }` | Side-effect on load completion | +| `mapSuccess { }` | Transform the inner value | +| `fold(onLoading, onSuccess, onFailure)` | Exhaustive pattern match | +| `getOrNull()` | Extract value or `null` | +| `getOrElse { }` | Extract value or a fallback | -val chat = engine.createSession(SessionConfig()).createChatSession("You are a math tutor.") - -chat.completion("Solve: 2^10 + 3^4 - 5^2").collect { completion -> - if (!completion.isComplete) return@collect // Wait for final token +--- - val reasoning = completion.thinkingText ?: "" - val answer = completion.contentText ?: "" +## Architecture - println("=== Model's Reasoning ===") - println(reasoning.take(500) + "...") // First 500 chars +```text +┌──────────────────────────────────┐ +│ LlamaChatSession (Public API) │ Formatting, stop tokens, metrics +├──────────────────────────────────┤ +│ LlamaSession (Public API) │ KV cache, mutex, token control +├──────────────────────────────────┤ +│ LlamaEngine (Public API) │ Model loading, session factory +├──────────────────────────────────┤ +│ JNI Bridge (Internal) │ C++ ↔ Kotlin, error mapping +├──────────────────────────────────┤ +│ llama.cpp (Native C++) │ GGML, SIMD (NEON, dotprod, i8mm) +└──────────────────────────────────┘ +``` - println("\n=== Final Answer ===") - println(answer) +All concrete implementations are `internal`. The public surface is interface-based. Extensions and wrappers can depend on the interfaces without coupling to the implementation. - println("\nPerformance: ${completion.tokensPerSecond} tokens/sec") -} -``` +--- -The thinking block is stripped from the visible response automatically—you get both simultaneously. +## Custom Prompt Formats -### Example 4: Custom Prompt Format for Unsupported Models +Any model not in the built-in list can be supported with a custom `PromptFormat`: ```kotlin -// For models not in the built-in templates, define your own format -val customFormat = PromptFormat( +val custom = PromptFormat( systemPrefix = "<>\n", - systemSuffix = "\n<>\n\n", userPrefix = "[INST] ", - userSuffix = " [/INST]\n", - assistantPrefix = "", - assistantSuffix = "\n" + assistantPrefix = "[/INST] ", + endOfTurn = "\n", + emitAssistantPrefixOnGeneration = true ) -val modelConfig = ModelConfig( - modelPath = "/data/models/my-custom-model.gguf", - promptFormat = customFormat +LlamaEngine.createFlow( + ModelDefinition( + loadConfig = ModelLoadConfig("/path/to/model.gguf"), + promptFormat = custom + ) ) - -// Use it like any built-in template -val engine = LlamaEngine.create(modelConfig) -val session = engine.createSession(SessionConfig()) -val chat = session.createChatSession("You are helpful.") - -// Test it with a known good prompt to verify format correctness -chat.completion("Test message").collect { completion -> - if (completion.contentText == null || completion.contentText.isEmpty()) { - println("⚠️ Format may be incorrect; model produced no output") - } else { - println("✓ Format working: ${completion.contentText}") - } -} -``` - -**Tip:** If responses are garbled or empty, double-check that the prefixes/suffixes match the model's training format exactly. Check the model's repository or card for the correct template. - ---- - -## Building from Source - -Clone with the llama.cpp submodule: - -```bash -git clone --recursive https://github.com/whyisitworking/llama-bro.git -cd llama-bro -``` - -Build the SDK: - -```bash -./gradlew :sdk:assembleRelease -``` - -Run tests: - -```bash -./gradlew :sdk:test ``` -### Build Requirements - -| Tool | Version | -|-------------|---------------| -| JDK | 17+ | -| Android SDK | API 36 | -| NDK | 29.0.14206865 | -| CMake | 3.22.1+ | - -NDK and CMake are installed automatically via the Android SDK manager. - ---- - -## Native Dependencies - -The SDK embeds `llama.cpp` as a vendored Git submodule. The CMake build compiles it directly into the `llama_bro` shared library. Key flags: - -- **`GGML_OPENCL = OFF`** — No GPU drivers; buggy, prevents UI stalls. -- **`GGML_OPENMP = ON`** — Multi-threaded decode. -- **`GGML_CPU_ALL_VARIANTS = ON`** — All CPU backends; best one selected at runtime. - -**Supported ABI:** `arm64-v8a` only. - ---- - -## Limitations & Roadmap - -### Current Limitations - -- **arm64-v8a only.** No x86_64 emulator support yet. -- **No GPU acceleration.** OpenCL is intentionally disabled, as it causes stalls. -- **Models must be local.** The library doesn't download; you manage model acquisition. -- **No multimodal.** Vision/audio models are not yet supported. -- **Memory constraints.** Full model must fit in RAM; typically limits practical use to 7B Q4 or smaller. - -### Roadmap - -- [ ] Streaming grammar / JSON-mode -- [ ] Function calling / tool use -- [ ] GGUF metadata reading -- [ ] LoRA adapter support - --- -## ProGuard / R8 - -Consumer ProGuard rules are built-in. The AAR automatically preserves: +## Roadmap -- JNI-accessible classes -- Native method signatures -- Fields read reflectively by JNI -- Callbacks invoked from native code - -No additional configuration required in your app. - ---- - -## Best Practices - -**Keep one engine per model.** Creating multiple engines for the same model wastes memory. Reuse the engine to create many sessions. - -**Use flow-based APIs for lifecycle safety.** `LlamaEngine.createFlow()` and `createSessionFlow()` automatically clean up resources on cancellation. Manual `.create()` requires explicit `.close()`. - -**Set `threads` to match performance cores.** Use `Runtime.availableProcessors() / 2` as a starting point, then benchmark. Too many threads causes contention; too few leaves cores idle. - -**Use `useMmap = true` by default.** Memory-mapping reduces peak RAM and doesn't hurt performance. Disable only if you have specific memory constraints. - -**Test your prompt format.** If responses are empty or gibberish, your `PromptFormat` is wrong. Try the model's README or a format from a similar model. - -**Handle `ContextOverflow` gracefully.** Use `OverflowStrategy.RollingWindow` for long conversations, or implement a "new conversation" flow when `LlamaError.ContextOverflow` is thrown. - -**Cache thinking blocks.** If using reasoning models, save `completion.thinkingText` separately from `contentText` for debugging or audit trails. - ---- - -## Troubleshooting - -| Symptom | Cause | Fix | -|--------------------------------------------|----------------------------------------------|---------------------------------------------------------------------------------------------------| -| Model loads but generation is **slow** | Too few threads; wrong quantisation | Increase `threads` to match core count; use Q4_K_M | -| **Out of memory (OOM)** on load | Model too large; `useMlock = true` | Use smaller model/quantisation; disable `useMlock` | -| **Blank or gibberish** responses | `PromptFormat` mismatch | Check model card for correct template; try built-in formats | -| **Crash on load** | File path wrong or file missing | Verify path exists; check file permissions | -| **Very slow prefill** on long prompts | Batch size too small | Increase `decodeConfig.batchSize` to `4096` | -| **Reasoning models** producing no thinking | Thinking not enabled in model config | Ensure model is trained with thinking blocks (DeepSeek-R1, QwQ); check `` tags in response | -| **Flow never completes** | Resource leak or cancellation not propagated | Use `.takeUntil()` to cancel flows; ensure lifecycle scope is cancellable | +- [ ] **Streaming Grammar** — Force structured JSON/function output from any model +- [ ] **Function Calling** — Registered tools that models can invoke during generation +- [ ] **Multi-Model Sessions** — Seamlessly switch models mid-conversation +- [ ] **GGUF Metadata** — Auto-detect model type and recommended format from file headers --- ## Contributing -Contributions are welcome. Before opening a PR: - -1. Open an issue to discuss non-trivial changes. -2. Run the full test suite: `./gradlew :sdk:test` +1. [Open an issue](https://github.com/whyisitworking/llama-bro/issues) to discuss non-trivial changes first +2. Run tests before submitting: `./gradlew :sdk:testDebugUnitTest` 3. Build the release AAR: `./gradlew :sdk:assembleRelease` -4. Follow the Kotlin official style guide. -5. Keep native code minimal and well-commented. +4. Follow the [Kotlin coding conventions](https://kotlinlang.org/docs/coding-conventions.html) +5. Keep native code minimal and its intent clear + +See [CLAUDE.md](CLAUDE.md) for architecture deep-dive and build setup. --- ## License -``` -Copyright 2024 whyisitworking - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -``` - -The embedded `llama.cpp` library is distributed under the [MIT License](https://github.com/ggml-org/llama.cpp/blob/master/LICENSE). +[Apache 2.0](LICENSE) --- -**What's next?** -Start with the [Quick Start](#quick-start). Download Gemma 3n. Build something. +

+ If Llama Bro saved you a weekend, give it a ⭐
+ Built with ❤️ for the Android + Local AI community. +

diff --git a/app/.gitignore b/app/.gitignore index 42afabf..65d12b9 100644 --- a/app/.gitignore +++ b/app/.gitignore @@ -1 +1,2 @@ -/build \ No newline at end of file +/build +google-services.json \ No newline at end of file diff --git a/app/build.gradle.kts b/app/build.gradle.kts index 006f471..8c44735 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -4,6 +4,8 @@ plugins { alias(libs.plugins.kotlin.ksp) alias(libs.plugins.hilt) alias(libs.plugins.kotlinx.serialization) + alias(libs.plugins.google.services) + alias(libs.plugins.firebase.crashlytics) } val dynamicVersionName = (project.findProperty("VERSION_NAME") as? String) ?: "1.0.0-SNAPSHOT" @@ -111,6 +113,11 @@ dependencies { // Markdown implementation(libs.compose.markdown) + // Firebase + implementation(platform(libs.firebase.bom)) + implementation(libs.firebase.analytics) + implementation(libs.firebase.crashlytics) + // SDK implementation(project(":sdk")) diff --git a/app/src/main/java/com/suhel/llamabro/demo/data/repository/ModelRepository.kt b/app/src/main/java/com/suhel/llamabro/demo/data/repository/ModelRepository.kt index 81e0570..57ec71f 100644 --- a/app/src/main/java/com/suhel/llamabro/demo/data/repository/ModelRepository.kt +++ b/app/src/main/java/com/suhel/llamabro/demo/data/repository/ModelRepository.kt @@ -6,8 +6,9 @@ import com.suhel.llamabro.demo.model.CurrentInferenceContext import com.suhel.llamabro.demo.model.Model import com.suhel.llamabro.demo.model.ModelDownloadState import com.suhel.llamabro.demo.model.ModelZoo -import com.suhel.llamabro.sdk.LlamaEngine -import com.suhel.llamabro.sdk.model.ModelConfig +import com.suhel.llamabro.sdk.engine.LlamaEngine +import com.suhel.llamabro.sdk.config.ModelDefinition +import com.suhel.llamabro.sdk.config.ModelLoadConfig import dagger.hilt.android.qualifiers.ApplicationContext import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers @@ -93,10 +94,11 @@ class ModelRepository @Inject constructor( .flatMapLatest { action -> when (action) { is Action.Load -> LlamaEngine.createFlow( - ModelConfig( - modelPath = action.model.file().absolutePath, - promptFormat = action.model.promptFormat, - supportsThinking = action.model.thinkingSupported + ModelDefinition( + loadConfig = ModelLoadConfig( + path = action.model.file().absolutePath + ), + promptFormat = action.model.promptFormat ) ).map { engine -> CurrentInferenceContext(action.model, engine) } diff --git a/app/src/main/java/com/suhel/llamabro/demo/model/CurrentInferenceContext.kt b/app/src/main/java/com/suhel/llamabro/demo/model/CurrentInferenceContext.kt index b2d8373..91c347e 100644 --- a/app/src/main/java/com/suhel/llamabro/demo/model/CurrentInferenceContext.kt +++ b/app/src/main/java/com/suhel/llamabro/demo/model/CurrentInferenceContext.kt @@ -1,6 +1,6 @@ package com.suhel.llamabro.demo.model -import com.suhel.llamabro.sdk.LlamaEngine +import com.suhel.llamabro.sdk.engine.LlamaEngine import com.suhel.llamabro.sdk.model.ResourceState data class CurrentInferenceContext( diff --git a/app/src/main/java/com/suhel/llamabro/demo/model/Model.kt b/app/src/main/java/com/suhel/llamabro/demo/model/Model.kt index 46a7e67..264a36d 100644 --- a/app/src/main/java/com/suhel/llamabro/demo/model/Model.kt +++ b/app/src/main/java/com/suhel/llamabro/demo/model/Model.kt @@ -1,7 +1,7 @@ package com.suhel.llamabro.demo.model -import com.suhel.llamabro.sdk.model.InferenceConfig -import com.suhel.llamabro.sdk.model.PromptFormat +import com.suhel.llamabro.sdk.config.InferenceConfig +import com.suhel.llamabro.sdk.format.PromptFormat data class Model( val id: String, diff --git a/app/src/main/java/com/suhel/llamabro/demo/model/ModelZoo.kt b/app/src/main/java/com/suhel/llamabro/demo/model/ModelZoo.kt index 6b422c0..269c31c 100644 --- a/app/src/main/java/com/suhel/llamabro/demo/model/ModelZoo.kt +++ b/app/src/main/java/com/suhel/llamabro/demo/model/ModelZoo.kt @@ -1,15 +1,26 @@ package com.suhel.llamabro.demo.model -import com.suhel.llamabro.sdk.model.InferenceConfig -import com.suhel.llamabro.sdk.model.PromptFormats +import com.suhel.llamabro.sdk.config.InferenceConfig +import com.suhel.llamabro.sdk.format.PromptFormats val ModelZoo = listOf( + Model( + id = "nemotron-mini-4b-instruct", + name = "Nemotron Mini 4B Instruct", + description = "Nemotron-Mini-4B-Instruct is a model for generating responses for roleplaying, retrieval augmented generation, and function calling.", + downloadUrl = "https://huggingface.co/bartowski/Nemotron-Mini-4B-Instruct-GGUF/resolve/main/Nemotron-Mini-4B-Instruct-Q6_K.gguf", + promptFormat = PromptFormats.NEMOTRON, + defaultInferenceConfig = InferenceConfig( + temperature = 1.0f, + topP = 1.0f, + ) + ), Model( id = "smollm2-135m-instruct", name = "SmolLM2 135M Instruct", description = "Ultra-lightweight model for absolute maximum tokens-per-second. Perfect for baseline speed tests.", downloadUrl = "https://huggingface.co/unsloth/SmolLM2-135M-Instruct-GGUF/resolve/main/SmolLM2-135M-Instruct-F16.gguf", - promptFormat = PromptFormats.ChatML, + promptFormat = PromptFormats.CHAT_ML, defaultInferenceConfig = InferenceConfig( temperature = 0.6f, repeatPenalty = 1.15f, @@ -24,7 +35,7 @@ val ModelZoo = listOf( name = "SmolLM2 360M Instruct", description = "Highly efficient sub-0.5B model balancing sheer speed with improved coherence.", downloadUrl = "https://huggingface.co/HuggingFaceTB/SmolLM2-360M-Instruct-GGUF/resolve/main/smollm2-360m-instruct-q8_0.gguf", - promptFormat = PromptFormats.ChatML, + promptFormat = PromptFormats.CHAT_ML, defaultInferenceConfig = InferenceConfig( temperature = 0.7f, repeatPenalty = 1.15f, @@ -39,7 +50,7 @@ val ModelZoo = listOf( name = "Qwen2.5 0.5B Instruct", description = "Exceptional speed with strong multilingual support and structured JSON formatting capabilities.", downloadUrl = "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/qwen2.5-0.5b-instruct-q8_0.gguf", - promptFormat = PromptFormats.ChatML, + promptFormat = PromptFormats.CHAT_ML, defaultInferenceConfig = InferenceConfig( temperature = 0.7f, repeatPenalty = 1.05f, @@ -54,7 +65,7 @@ val ModelZoo = listOf( name = "Qwen3.5 2B", description = "Multimodal thinking model with advanced reasoning capabilities.", downloadUrl = "https://huggingface.co/unsloth/Qwen3.5-2B-GGUF/resolve/main/Qwen3.5-2B-UD-Q5_K_XL.gguf", - promptFormat = PromptFormats.ChatML, + promptFormat = PromptFormats.CHAT_ML, defaultInferenceConfig = InferenceConfig( temperature = 1.0f, topP = 0.95f, @@ -71,7 +82,7 @@ val ModelZoo = listOf( name = "Llama-3.2 1B Instruct", description = "Meta's highly optimized 1B mobile model. The industry standard for reliable on-device chat.", downloadUrl = "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q5_K_M.gguf", - promptFormat = PromptFormats.Llama3, + promptFormat = PromptFormats.LLAMA_3, defaultInferenceConfig = InferenceConfig( temperature = 0.6f, repeatPenalty = 1.1f, @@ -86,7 +97,7 @@ val ModelZoo = listOf( name = "DeepSeek-R1 1.5B (Distilled)", description = "Advanced reasoning capabilities on-device. Uses chain-of-thought processing.", downloadUrl = "https://huggingface.co/unsloth/DeepSeek-R1-Distill-Qwen-1.5B-GGUF/resolve/main/DeepSeek-R1-Distill-Qwen-1.5B-Q8_0.gguf", - promptFormat = PromptFormats.ChatML, + promptFormat = PromptFormats.CHAT_ML, defaultInferenceConfig = InferenceConfig( temperature = 0.6f, repeatPenalty = 1.0f, @@ -103,7 +114,7 @@ val ModelZoo = listOf( name = "SmolLM2 1.7B Instruct", description = "High-quality, nuanced text generation that punches above its weight class.", downloadUrl = "https://huggingface.co/bartowski/SmolLM2-1.7B-Instruct-GGUF/resolve/main/SmolLM2-1.7B-Instruct-Q5_K_M.gguf", - promptFormat = PromptFormats.ChatML, + promptFormat = PromptFormats.CHAT_ML, defaultInferenceConfig = InferenceConfig( temperature = 0.7f, repeatPenalty = 1.1f, @@ -114,3 +125,4 @@ val ModelZoo = listOf( ) ) ) + diff --git a/app/src/main/java/com/suhel/llamabro/demo/ui/screens/chat/ChatViewModel.kt b/app/src/main/java/com/suhel/llamabro/demo/ui/screens/chat/ChatViewModel.kt index 0f8f2e3..6b44a31 100644 --- a/app/src/main/java/com/suhel/llamabro/demo/ui/screens/chat/ChatViewModel.kt +++ b/app/src/main/java/com/suhel/llamabro/demo/ui/screens/chat/ChatViewModel.kt @@ -14,8 +14,8 @@ import com.suhel.llamabro.demo.data.repository.ChatRepository import com.suhel.llamabro.demo.data.repository.ModelRepository import com.suhel.llamabro.demo.model.MessageRole import com.suhel.llamabro.demo.navigation.Chat -import com.suhel.llamabro.sdk.model.Message -import com.suhel.llamabro.sdk.model.SessionConfig +import com.suhel.llamabro.sdk.models.ChatEvent +import com.suhel.llamabro.sdk.config.SessionConfig import com.suhel.llamabro.sdk.model.filterSuccess import com.suhel.llamabro.sdk.model.flatMapResource import com.suhel.llamabro.sdk.model.getOrNull @@ -53,6 +53,8 @@ class ChatViewModel @Inject constructor( """.trimIndent() private const val MAX_TITLE_LENGTH = 50 + + private const val STREAMING_ID = "streaming" } /** Null until the first message is sent (new conversation flow). */ @@ -89,18 +91,20 @@ class ChatViewModel @Inject constructor( val history = chatRepository.getMessages(id) .map { chatMessage -> when (chatMessage.role) { - MessageRole.User -> Message.User( - content = chatMessage.content + MessageRole.User -> ChatEvent.UserEvent( + content = chatMessage.content, + think = false ) - MessageRole.Assistant -> Message.Assistant( - content = chatMessage.content, - thinking = chatMessage.thinking + MessageRole.Assistant -> ChatEvent.AssistantEvent( + parts = listOf( + ChatEvent.AssistantEvent.Part.TextPart(chatMessage.content) + ) ) } } - chatSession.loadHistory(history) + chatSession.feedHistory(history) } .stateIn(viewModelScope, SharingStarted.Eagerly, null) @@ -186,30 +190,31 @@ class ChatViewModel @Inject constructor( .filterNotNull() .flatMapLatest { chatSession -> chatSession.completion( - prompt, - enableThinking, - model.defaultMaxThinkingTokens + ChatEvent.UserEvent( + content = prompt, + think = enableThinking + ) ) } .onEach { chunk -> - if (chunk.isComplete && chunk.error == null - && (chunk.contentText != null || chunk.thinkingText != null) + if (chunk.isComplete && !chunk.isError + && (chunk.message.text.isNotEmpty() || chunk.message.thinkingText.isNotEmpty()) ) { chatRepository.addMessage( conversationId = conversationId, role = MessageRole.Assistant, - content = chunk.contentText.orEmpty(), - thinking = chunk.thinkingText, + content = chunk.message.text, + thinking = chunk.message.thinkingText.takeIf { it.isNotEmpty() }, tokensPerSecond = chunk.tokensPerSecond ) } } .map { chunk -> when { - chunk.error != null -> UiChatMessage( + chunk.isError -> UiChatMessage( id = "streaming", role = MessageRole.Assistant, - error = chunk.error!!.message + error = chunk.error ) chunk.isComplete -> null @@ -217,8 +222,8 @@ class ChatViewModel @Inject constructor( else -> UiChatMessage( id = "streaming", role = MessageRole.Assistant, - content = chunk.contentText, - thinking = chunk.thinkingText + content = chunk.message.text.takeIf { it.isNotEmpty() }, + thinking = chunk.message.thinkingText.takeIf { it.isNotEmpty() } ) } } @@ -238,9 +243,9 @@ class ChatViewModel @Inject constructor( .stateIn(viewModelScope, SharingStarted.Eagerly, null) val inputConfig = combine( - chatSessionFlow.map { it?.supportsThinking == true }, + currentModelFlow.map { it?.thinkingSupported == true }, incomingMessage.map { it != null } - ) { (supportsThinking, isGenerating) -> + ) { supportsThinking, isGenerating -> UiChatInputConfig( thinkingSupported = supportsThinking, isGenerating = isGenerating diff --git a/build.gradle.kts b/build.gradle.kts index b16ed4b..e46e181 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -2,8 +2,10 @@ plugins { alias(libs.plugins.android.application) apply false alias(libs.plugins.kotlin.compose) apply false + alias(libs.plugins.kotlinx.serialization) apply false alias(libs.plugins.android.library) apply false alias(libs.plugins.kotlin.ksp) apply false alias(libs.plugins.hilt) apply false - alias(libs.plugins.kotlinx.serialization) apply false + alias(libs.plugins.google.services) apply false + alias(libs.plugins.firebase.crashlytics) apply false } \ No newline at end of file diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 9712882..909a6d0 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -22,6 +22,9 @@ hiltNavigation = "1.3.0" paging = "3.4.2" pagingCompose = "3.5.0-alpha01" composeMarkdown = "0.6.0" +googleServices = "4.4.4" +firebaseBom = "34.11.0" +firebaseCrashlytics = "3.0.6" [libraries] androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" } @@ -57,6 +60,9 @@ hilt-navigation-compose = { group = "androidx.hilt", name = "hilt-navigation-com androidx-paging-runtime = { group = "androidx.paging", name = "paging-runtime", version.ref = "paging" } androidx-paging-compose = { group = "androidx.paging", name = "paging-compose", version.ref = "pagingCompose" } compose-markdown = { group = "com.github.jeziellago", name = "compose-markdown", version.ref = "composeMarkdown" } +firebase-bom = { group = "com.google.firebase", name = "firebase-bom", version.ref = "firebaseBom" } +firebase-analytics = { group = "com.google.firebase", name = "firebase-analytics" } +firebase-crashlytics = { group = "com.google.firebase", name = "firebase-crashlytics" } [plugins] android-application = { id = "com.android.application", version.ref = "agp" } @@ -65,3 +71,6 @@ kotlinx-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", vers android-library = { id = "com.android.library", version.ref = "agp" } kotlin-ksp = { id = "com.google.devtools.ksp", version.ref = "ksp" } hilt = { id = "com.google.dagger.hilt.android", version.ref = "hilt" } +google-services = { id = "com.google.gms.google-services", version.ref = "googleServices" } +firebase-crashlytics = { id = "com.google.firebase.crashlytics", version.ref = "firebaseCrashlytics" } +maven-publish = { id = "maven-publish" } diff --git a/sdk/build.gradle.kts b/sdk/build.gradle.kts index fa7a8d4..f19a4c6 100644 --- a/sdk/build.gradle.kts +++ b/sdk/build.gradle.kts @@ -1,6 +1,7 @@ plugins { alias(libs.plugins.android.library) - id("maven-publish") + alias(libs.plugins.kotlinx.serialization) + alias(libs.plugins.maven.publish) } val dynamicVersionName = (project.findProperty("VERSION_NAME") as? String) ?: "1.0.0-SNAPSHOT" @@ -84,4 +85,6 @@ dependencies { testImplementation(libs.kotlinx.coroutines.test) androidTestImplementation(libs.androidx.junit) androidTestImplementation(libs.androidx.espresso.core) + + implementation(libs.kotlinx.serialization.json) } diff --git a/sdk/consumer-rules.pro b/sdk/consumer-rules.pro index d78d8d0..b0c59a7 100644 --- a/sdk/consumer-rules.pro +++ b/sdk/consumer-rules.pro @@ -19,6 +19,6 @@ } # Keep ProgressListener.onProgress — called from native code via JNI CallBooleanMethod --keepclassmembers class * implements com.suhel.llamabro.sdk.internal.ProgressListener { +-keepclassmembers class * implements com.suhel.llamabro.sdk.ProgressListener { boolean onProgress(float); } diff --git a/sdk/src/main/cpp/CMakeLists.txt b/sdk/src/main/cpp/CMakeLists.txt index b61b129..ba563d2 100644 --- a/sdk/src/main/cpp/CMakeLists.txt +++ b/sdk/src/main/cpp/CMakeLists.txt @@ -1,42 +1,36 @@ cmake_minimum_required(VERSION 3.22.1) project(llama_bro C CXX) -# Enforce C11 for GGML set(CMAKE_C_STANDARD 11) -set(CMAKE_C_STANDARD_REQUIRED ON) +set(CMAKE_C_STANDARD_REQUIRED true) -# Enforce C++17 for llama.cpp set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_STANDARD_REQUIRED true) -# Disable OpenCL as it causes UI freezes -set(GGML_OPENCL OFF CACHE BOOL "" FORCE) -set(GGML_OPENCL_EMBED_KERNELS OFF CACHE BOOL "" FORCE) -set(GGML_OPENCL_USE_ADRENO_KERNELS OFF CACHE BOOL "" FORCE) +# Needed for multiple CPU variants +set(GGML_BACKEND_DL ON) -# Need OpenMP for multithreading -set(GGML_OPENMP ON CACHE BOOL "" FORCE) +# CPU is the most stable backend on mobile now +set(GGML_CPU ON) -# Need multiple backends for runtime selection based on support -set(GGML_CPU_ALL_VARIANTS ON CACHE BOOL "" FORCE) +# Compile GGML for all CPU variants +set(GGML_CPU_ALL_VARIANTS ON) -# Need dynamic linkage for choosing backend -set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) -set(GGML_BACKEND_DL ON CACHE BOOL "" FORCE) +# OpenMP mismanages big.LITTLE architectures +set(GGML_OPENMP OFF) -# No point in building standalone -set(GGML_STANDALONE OFF CACHE BOOL "" FORCE) +# Use optimized kernels +set(GGML_CPU_KLEIDIAI ON) add_subdirectory(external/llama.cpp) add_library(llama_bro SHARED - engine.cpp - session.cpp + engine/engine.cpp + session/session.cpp - # The interops - jni/llama_engine_jni.cpp - jni/llama_session_jni.cpp + jni/engine.cpp + jni/session.cpp ) target_include_directories(llama_bro diff --git a/sdk/src/main/cpp/engine.cpp b/sdk/src/main/cpp/engine.cpp deleted file mode 100644 index dfb6af8..0000000 --- a/sdk/src/main/cpp/engine.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include "engine.h" -#include "utils/ggml_variant_chooser.h" -#include "utils/llama_exception.h" - -LlamaEngine::LlamaEngine(const NativeEngineParams &config) : threads{config.threads} { - auto backend_result = ggml_backend_load(resolve_best_ggml_backend()); - if (backend_result == nullptr) { - throw LlamaException(LlamaErrorCode::BACKEND_LOAD_FAILED, resolve_best_ggml_backend()); - } - - llama_backend_init(); - - auto params = llama_model_default_params(); - params.use_mmap = config.use_mmap; - params.use_mlock = config.use_mlock; - params.n_gpu_layers = 0; // CPU-only for now - - // Wire up progress callback if provided - if (config.progress_callback) { - // Stack-safe: std::function lives in config, which lives in the JNI frame. - // llama_model_load_from_file is synchronous — the callback is only called here. - params.progress_callback_user_data = const_cast *>(&config.progress_callback); - params.progress_callback = [](float progress, void *user_data) -> bool { - auto *cb = static_cast *>(user_data); - return (*cb)(progress); - }; - } - - auto *model = llama_model_load_from_file(config.model_path.c_str(), params); - - if (!model) { - llama_backend_free(); - throw LlamaException(LlamaErrorCode::MODEL_LOAD_FAILED, config.model_path); - } - - llama_model.reset(model); -} - -LlamaEngine::~LlamaEngine() { - llama_backend_free(); -} - -LlamaSession *LlamaEngine::session(const NativeSessionParams &config) { - return new LlamaSession(llama_model.get(), threads, config); -} diff --git a/sdk/src/main/cpp/engine.h b/sdk/src/main/cpp/engine.h deleted file mode 100644 index 8390a0e..0000000 --- a/sdk/src/main/cpp/engine.h +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#include -#include - -#include "llama-cpp.h" -#include "session.h" - -struct NativeEngineParams { - std::string model_path; - int threads; - bool use_mmap; - bool use_mlock; - // Optional progress callback; nullptr means no progress reporting - std::function progress_callback = nullptr; -}; - -class LlamaEngine { -private: - llama_model_ptr llama_model; - int threads; - -public: - LlamaEngine(const NativeEngineParams &config); - - ~LlamaEngine(); - - LlamaEngine(const LlamaEngine &) = delete; - LlamaEngine(LlamaEngine &&) = delete; - LlamaEngine &operator=(const LlamaEngine &) = delete; - LlamaEngine &operator=(LlamaEngine &&) = delete; - - LlamaSession *session(const NativeSessionParams &config); -}; diff --git a/sdk/src/main/cpp/engine/engine.cpp b/sdk/src/main/cpp/engine/engine.cpp new file mode 100644 index 0000000..8feeaa7 --- /dev/null +++ b/sdk/src/main/cpp/engine/engine.cpp @@ -0,0 +1,55 @@ +#include + +#include "engine.hpp" +#include "ggml_variant_chooser.hpp" + +#include "session/session.hpp" +#include "result/codes.hpp" +#include "ggml-backend.h" +#include "utils/log.hpp" + +namespace engine { + Engine::Engine(const NativeEngineParams &config) { + auto backend = resolve_best_ggml_backend(); + auto backend_result = ggml_backend_load(backend); + if (backend_result == nullptr) { + throw std::runtime_error("Failed to load GGML backend."); + } + + llama_backend_init(); + + auto params = llama_model_default_params(); + params.use_mmap = config.use_mmap; + params.use_mlock = config.use_mlock; + params.n_gpu_layers = 0; // CPU-only for now + + // Wire up progress callback if provided + if (config.progress_callback) { + // Stack-safe: std::function lives in config, which lives in the JNI frame. + // llama_model_load_from_file is synchronous — the callback is only called here. + params.progress_callback_user_data = const_cast *>(&config.progress_callback); + params.progress_callback = [](float progress, void *user_data) -> bool { + auto *cb = static_cast *>(user_data); + return (*cb)(progress); + }; + } + + auto *model = llama_model_load_from_file(config.model_path.c_str(), params); + + if (!model) { + llama_backend_free(); + throw std::runtime_error("Failed to load model."); + } + + llama_model.reset(model); + } + + Engine::~Engine() { + llama_backend_free(); + } + + session::Session *Engine::session(const session::NativeSessionParams &config) { + return new session::Session(llama_model.get(), config); + } +} diff --git a/sdk/src/main/cpp/engine/engine.hpp b/sdk/src/main/cpp/engine/engine.hpp new file mode 100644 index 0000000..f0e5045 --- /dev/null +++ b/sdk/src/main/cpp/engine/engine.hpp @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + +#include "llama-cpp.h" +#include "session/session.hpp" + +namespace engine { + struct NativeEngineParams { + std::string model_path; + int threads; + bool use_mmap; + bool use_mlock; + // Optional progress callback; nullptr means no progress reporting + std::function progress_callback = nullptr; + }; + + class Engine { + private: + llama_model_ptr llama_model; + + public: + Engine(const NativeEngineParams &config); + + ~Engine(); + + Engine(const Engine &) = delete; + + Engine(Engine &&) = delete; + + Engine &operator=(const Engine &) = delete; + + Engine &operator=(Engine &&) = delete; + + session::Session *session(const session::NativeSessionParams &config); + }; +} diff --git a/sdk/src/main/cpp/engine/ggml_variant_chooser.hpp b/sdk/src/main/cpp/engine/ggml_variant_chooser.hpp new file mode 100644 index 0000000..3428160 --- /dev/null +++ b/sdk/src/main/cpp/engine/ggml_variant_chooser.hpp @@ -0,0 +1,84 @@ +#pragma once + +#include +#include +#include "utils/log.hpp" + +static inline bool has_bit(unsigned long value, unsigned long bit) { + return (value & bit) != 0; +} + +/** + * Probes the CPU for features and chooses the most optimized backend for GGML + * for fastest inference without encountering SIGILL + * + * @return Path to the best GGML backend + */ +const char *resolve_best_ggml_backend() { + const unsigned long hwcap = getauxval(AT_HWCAP); + const unsigned long hwcap2 = getauxval(AT_HWCAP2); + + bool fp16 = has_bit(hwcap, HWCAP_ASIMDHP); + bool dotprod = has_bit(hwcap, HWCAP_ASIMDDP); + bool sve = has_bit(hwcap, HWCAP_SVE); + + bool sve2 = has_bit(hwcap2, HWCAP2_SVE2); + bool i8mm = has_bit(hwcap2, HWCAP2_I8MM); + bool sme = has_bit(hwcap2, HWCAP2_SME); + + // Defensive normalization against bad kernel reports + if (!dotprod) { + fp16 = false; + i8mm = false; + sve2 = false; + sme = false; + } + + if (!fp16) { + i8mm = false; + sve2 = false; + sme = false; + } + + if (!i8mm) { + sve2 = false; + sme = false; + } + + if (!sve) { + sme = false; + } + + if (dotprod && fp16 && i8mm && sve && sve2 && sme) { + LOGI("DOTPROD + FP16 + I8MM + SVE + SVE2 + SME = libggml-cpu-android_armv9.2_2.so"); + return "libggml-cpu-android_armv9.2_2.so"; + } + + if (dotprod && fp16 && i8mm && sve && sme) { + LOGI("DOTPROD + FP16 + I8MM + SVE + SME = libggml-cpu-android_armv9.2_1.so"); + return "libggml-cpu-android_armv9.2_1.so"; + } + + if (dotprod && fp16 && i8mm && sve2) { + LOGI("DOTPROD + FP16 + I8MM + SVE2 = libggml-cpu-android_armv9.0_1.so"); + return "libggml-cpu-android_armv9.0_1.so"; + } + + if (dotprod && fp16 && i8mm) { + LOGI("DOTPROD + FP16 + I8MM = libggml-cpu-android_armv8.6_1.so"); + return "libggml-cpu-android_armv8.6_1.so"; + } + + if (dotprod && fp16) { + LOGI("DOTPROD + FP16 = libggml-cpu-android_armv8.2_2.so"); + return "libggml-cpu-android_armv8.2_2.so"; + } + + if (dotprod) { + LOGI("DOTPROD = libggml-cpu-android_armv8.2_1.so"); + return "libggml-cpu-android_armv8.2_1.so"; + } + + LOGI("Base = libggml-cpu-android_armv8.0_1.so"); + return "libggml-cpu-android_armv8.0_1.so"; +} diff --git a/sdk/src/main/cpp/external/llama.cpp b/sdk/src/main/cpp/external/llama.cpp index b6c83aa..42ebce3 160000 --- a/sdk/src/main/cpp/external/llama.cpp +++ b/sdk/src/main/cpp/external/llama.cpp @@ -1 +1 @@ -Subproject commit b6c83aad55a4ce17ec96fced7770cd1be8758193 +Subproject commit 42ebce3bebeab64fbe71b667d1bafd9960e83cbf diff --git a/sdk/src/main/cpp/utils/jni_config_reader.h b/sdk/src/main/cpp/jni/config_reader.hpp similarity index 100% rename from sdk/src/main/cpp/utils/jni_config_reader.h rename to sdk/src/main/cpp/jni/config_reader.hpp diff --git a/sdk/src/main/cpp/jni/llama_engine_jni.cpp b/sdk/src/main/cpp/jni/engine.cpp similarity index 81% rename from sdk/src/main/cpp/jni/llama_engine_jni.cpp rename to sdk/src/main/cpp/jni/engine.cpp index 9f841b4..0b2a1b0 100644 --- a/sdk/src/main/cpp/jni/llama_engine_jni.cpp +++ b/sdk/src/main/cpp/jni/engine.cpp @@ -1,8 +1,7 @@ #include -#include "utils/jni_config_reader.h" -#include "utils/jni_error_thrower.h" -#include "utils/llama_exception.h" -#include "engine.h" +#include "config_reader.hpp" +#include "result/codes.hpp" +#include "engine/engine.hpp" #include "llama.h" // ── Shared helper ───────────────────────────────────────────────────────────── @@ -12,14 +11,14 @@ namespace jni_refs { constexpr auto progress_listener_method_sig = "(F)Z"; } -static NativeEngineParams readEngineParams(JNIEnv *env, +static engine::NativeEngineParams readEngineParams(JNIEnv *env, jobject jConfig) { auto configReader = JniConfigReader(env, jConfig); - return NativeEngineParams{ + return engine::NativeEngineParams{ .model_path = configReader.getString("modelPath"), .threads = configReader.getInt("threads"), - .use_mmap = configReader.getBool("useMmap"), - .use_mlock = configReader.getBool("useMlock"), + .use_mmap = configReader.getBool("useMMap"), + .use_mlock = configReader.getBool("useMLock"), }; } @@ -30,10 +29,9 @@ JNIEXPORT jlong JNICALL Java_com_suhel_llamabro_sdk_internal_LlamaEngineImpl_00024Jni_create(JNIEnv *env, jclass, jobject jConfig) { try { - auto instance = new LlamaEngine(readEngineParams(env, jConfig)); + auto instance = new engine::Engine(readEngineParams(env, jConfig)); return reinterpret_cast(instance); - } catch (const LlamaException &ex) { - throwLlamaError(env, ex); + } catch (const std::exception &ex) { return 0L; } } @@ -62,10 +60,9 @@ Java_com_suhel_llamabro_sdk_internal_LlamaEngineImpl_00024Jni_createWithProgress }; try { - auto instance = new LlamaEngine(config); + auto instance = new engine::Engine(config); return reinterpret_cast(instance); - } catch (const LlamaException &ex) { - throwLlamaError(env, ex); + } catch (const std::exception &ex) { return 0L; } } @@ -74,5 +71,5 @@ extern "C" JNIEXPORT void JNICALL Java_com_suhel_llamabro_sdk_internal_LlamaEngineImpl_00024Jni_destroy(JNIEnv *, jclass, jlong jEnginePtr) { - delete reinterpret_cast(jEnginePtr); + delete reinterpret_cast(jEnginePtr); } diff --git a/sdk/src/main/cpp/jni/llama_session_jni.cpp b/sdk/src/main/cpp/jni/session.cpp similarity index 61% rename from sdk/src/main/cpp/jni/llama_session_jni.cpp rename to sdk/src/main/cpp/jni/session.cpp index b5383e7..ca3c33a 100644 --- a/sdk/src/main/cpp/jni/llama_session_jni.cpp +++ b/sdk/src/main/cpp/jni/session.cpp @@ -1,25 +1,33 @@ #include -#include "utils/jni_config_reader.h" -#include "utils/jni_error_thrower.h" -#include "utils/llama_exception.h" -#include "engine.h" +#include + +#include "config_reader.hpp" +#include "result/codes.hpp" +#include "engine/engine.hpp" +#include "session/session.hpp" namespace jni_refs { - constexpr auto token_generation_result_class = "com/suhel/llamabro/sdk/internal/LlamaSessionImpl$NativeTokenGenerationResult"; - constexpr auto token_generation_result_constructor_sig = "(Ljava/lang/String;Z)V"; + constexpr auto result_class = "com/suhel/llamabro/sdk/internal/LlamaSessionImpl$NativeTokenGenerationResult"; + constexpr auto result_field_token = "token"; + constexpr auto result_field_result = "resultCode"; + constexpr auto result_field_is_complete = "isComplete"; } -static jclass jTokenGenerationResultClass = nullptr; -static jmethodID jTokenGenerationResultConstructor = nullptr; +static jfieldID jResultFieldToken = nullptr; +static jfieldID jResultFieldResultCode = nullptr; +static jfieldID jResultFieldIsComplete = nullptr; static void cache_refs(JNIEnv *env) { - auto local = env->FindClass(jni_refs::token_generation_result_class); + auto class_ref = env->FindClass(jni_refs::result_class); + + jResultFieldToken = env->GetFieldID(class_ref, + jni_refs::result_field_token, "Ljava/lang/String;"); + jResultFieldResultCode = env->GetFieldID(class_ref, + jni_refs::result_field_result, "I"); + jResultFieldIsComplete = env->GetFieldID(class_ref, + jni_refs::result_field_is_complete, "Z"); - jTokenGenerationResultClass = reinterpret_cast(env->NewGlobalRef(local)); - jTokenGenerationResultConstructor = env->GetMethodID(jTokenGenerationResultClass, - "", - jni_refs::token_generation_result_constructor_sig); - env->DeleteLocalRef(local); + env->DeleteLocalRef(class_ref); } JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *) { @@ -40,10 +48,10 @@ Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_create(JNIEnv *en jclass, jlong jEnginePtr, jobject jParams) { - auto engine = reinterpret_cast(jEnginePtr); + auto engine = reinterpret_cast(jEnginePtr); auto configReader = JniConfigReader(env, jParams); - auto config = NativeSessionParams{ + auto config = session::NativeSessionParams{ .context_size = configReader.getInt("contextSize"), .overflow_strategy_id = configReader.getInt("overflowStrategyId"), .overflow_drop_tokens = configReader.getInt("overflowDropTokens"), @@ -63,8 +71,7 @@ Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_create(JNIEnv *en try { return reinterpret_cast(engine->session(config)); - } catch (const LlamaException &ex) { - throwLlamaError(env, ex); + } catch (const std::exception &ex) { return 0L; } } @@ -75,36 +82,26 @@ extern "C" JNIEXPORT void JNICALL Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_setSystemPrompt(JNIEnv *env, jclass, jlong jSessionPtr, - jstring jText, - jboolean jAddSpecial) { - auto session = reinterpret_cast(jSessionPtr); + jstring jText) { + auto session = reinterpret_cast(jSessionPtr); auto text = env->GetStringUTFChars(jText, nullptr); - auto textStr = std::string(text); + auto text_str = std::string(text); env->ReleaseStringUTFChars(jText, text); - try { - session->setSystemPrompt(textStr, jAddSpecial); - } catch (const LlamaException &ex) { - throwLlamaError(env, ex); - } + session->set_system_prompt(text_str); } extern "C" JNIEXPORT void JNICALL -Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_ingestPrompt(JNIEnv *env, jclass, - jlong jSessionPtr, - jstring jText, - jboolean jAddSpecial) { - auto session = reinterpret_cast(jSessionPtr); +Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_addUserPrompt(JNIEnv *env, jclass, + jlong jSessionPtr, + jstring jText) { + auto session = reinterpret_cast(jSessionPtr); auto text = env->GetStringUTFChars(jText, nullptr); - auto textStr = std::string(text); + auto text_str = std::string(text); env->ReleaseStringUTFChars(jText, text); - try { - session->ingestPrompt(textStr, jAddSpecial); - } catch (const LlamaException &ex) { - throwLlamaError(env, ex); - } + session->add_user_prompt(text_str); } // ── clear ───────────────────────────────────────────────────────────────────── @@ -113,13 +110,8 @@ extern "C" JNIEXPORT void JNICALL Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_clear(JNIEnv *env, jclass, jlong jSessionPtr) { - auto session = reinterpret_cast(jSessionPtr); - - try { - session->clear(); - } catch (const LlamaException &ex) { - throwLlamaError(env, ex); - } + auto session = reinterpret_cast(jSessionPtr); + session->clear(); } // ── abort ───────────────────────────────────────────────────────────────────── @@ -128,34 +120,32 @@ extern "C" JNIEXPORT void JNICALL Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_abort(JNIEnv *, jclass, jlong jSessionPtr) { - auto session = reinterpret_cast(jSessionPtr); + auto session = reinterpret_cast(jSessionPtr); session->abort(); } // ── generate ───────────────────────────────────────────────────────────────── extern "C" -JNIEXPORT jobject JNICALL +JNIEXPORT void JNICALL Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_generate(JNIEnv *env, jclass, - jlong jSessionPtr) { - auto session = reinterpret_cast(jSessionPtr); - - try { - auto gen = session->generate(); - auto token = gen.token; - - auto jToken = token.has_value() - ? env->NewString(reinterpret_cast(token.value().data()), - static_cast(token.value().size())) - : nullptr; - auto jIsComplete = static_cast(gen.is_complete); - - return env->NewObject(jTokenGenerationResultClass, jTokenGenerationResultConstructor, - jToken, jIsComplete); - } catch (const LlamaException &ex) { - throwLlamaError(env, ex); - return nullptr; - } + jlong jSessionPtr, + jobject jResultObj) { + auto session = reinterpret_cast(jSessionPtr); + auto gen = session->generate(); + auto token = gen.token; + auto result_code = gen.result_code; + + auto jToken = token + ? env->NewString(reinterpret_cast(token->data()), + static_cast(token->size())) + : nullptr; + auto jResultCode = static_cast(result_code); + auto jIsComplete = static_cast(gen.is_complete); + + env->SetObjectField(jResultObj, jResultFieldToken, jToken); + env->SetIntField(jResultObj, jResultFieldResultCode, jResultCode); + env->SetBooleanField(jResultObj, jResultFieldIsComplete, jIsComplete); } // ── destroy ─────────────────────────────────────────────────────────────────── @@ -164,6 +154,6 @@ extern "C" JNIEXPORT void JNICALL Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_destroy(JNIEnv *, jclass, jlong jSessionPtr) { - auto session = reinterpret_cast(jSessionPtr); + auto session = reinterpret_cast(jSessionPtr); delete session; } diff --git a/sdk/src/main/cpp/parsers/tag.hpp b/sdk/src/main/cpp/parsers/tag.hpp new file mode 100644 index 0000000..3abc984 --- /dev/null +++ b/sdk/src/main/cpp/parsers/tag.hpp @@ -0,0 +1,165 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace parsers { + + struct NormalContent { + std::string text; + }; + struct TagContent { + int tag_id; + std::string text; + }; + using EmitEvent = std::variant; + + class TagParser { + public: + void add(int id, std::string_view start, std::string_view end) { + assert(!feeding && "add() must not be called after feed() has begun"); + assert(!start.empty() && !end.empty()); + + tags.push_back({id, std::string(start), std::string(end)}); + + if (trigger_chars.find(start[0]) == std::string::npos) + trigger_chars += start[0]; + } + + bool enter_tag(int id) { + for (size_t i = 0; i < tags.size(); ++i) { + if (tags[i].id == id) { + active_tag_idx = i; + return true; + } + } + return false; + } + + std::vector feed(std::string_view token) { + feeding = true; + buffer += token; // Combine any leftover prefix with new input + std::string_view input = buffer; + std::vector events; + + // Helper to merge adjacent events of the same type + auto emit = [&](std::string_view text) { + if (text.empty()) return; + + if (active_tag_idx) { + int id = tags[*active_tag_idx].id; + if (!events.empty() + && std::holds_alternative(events.back()) + && std::get(events.back()).tag_id == id) { + std::get(events.back()).text.append(text); + return; + } + events.emplace_back(TagContent{id, std::string(text)}); + } else { + if (!events.empty() && std::holds_alternative(events.back())) { + std::get(events.back()).text.append(text); + return; + } + events.emplace_back(NormalContent{std::string(text)}); + } + }; + + while (!input.empty()) { + if (active_tag_idx) { + const std::string &end_tag = tags[*active_tag_idx].end; + size_t pos = input.find(end_tag[0]); + + if (pos == std::string_view::npos) { + emit(input); + input = {}; // Consume all + break; + } + + emit(input.substr(0, pos)); + input.remove_prefix(pos); + + if (input.size() >= end_tag.size() + && input.substr(0, end_tag.size()) == end_tag) { + active_tag_idx.reset(); + input.remove_prefix(end_tag.size()); + } else if (end_tag.compare(0, input.size(), input) == 0) { + break; // Partial match at the end of input; keep in buffer + } else { + emit(input.substr(0, 1)); // False alarm: emit the trigger char + input.remove_prefix(1); // and move on + } + } else { + size_t pos = input.find_first_of(trigger_chars); + + if (pos == std::string_view::npos) { + emit(input); + input = {}; + break; + } + + emit(input.substr(0, pos)); + input.remove_prefix(pos); + + bool matched = false; + bool partial = false; + + for (size_t i = 0; i < tags.size(); ++i) { + const std::string &start_tag = tags[i].start; + if (input.size() >= start_tag.size() && + input.substr(0, start_tag.size()) == start_tag) { + active_tag_idx = i; + input.remove_prefix(start_tag.size()); + matched = true; + break; + } else if (start_tag.compare(0, input.size(), input) == 0) { + partial = true; + } + } + + if (matched) continue; + if (partial) break; // Partial match; keep remainder in buffer + + // False alarm: emit trigger char and continue + emit(input.substr(0, 1)); + input.remove_prefix(1); + } + } + + buffer = std::string(input); // Retain any partial matches + return events; + } + + std::vector flush() { + std::vector events; + if (!buffer.empty()) { + if (active_tag_idx) { + events.emplace_back(TagContent{tags[*active_tag_idx].id, std::move(buffer)}); + } else { + events.emplace_back(NormalContent{std::move(buffer)}); + } + buffer.clear(); + } + active_tag_idx.reset(); + feeding = false; + return events; + } + + private: + struct Tag { + int id; + std::string start; + std::string end; + }; + + std::vector tags; + std::optional active_tag_idx; + std::string trigger_chars; + std::string buffer; + bool feeding = false; + }; + +} // namespace parsers \ No newline at end of file diff --git a/sdk/src/main/cpp/parsers/token.hpp b/sdk/src/main/cpp/parsers/token.hpp new file mode 100644 index 0000000..0cc7ad1 --- /dev/null +++ b/sdk/src/main/cpp/parsers/token.hpp @@ -0,0 +1,154 @@ +#pragma once + +#include +#include + +namespace parsers { + static constexpr unsigned char CONTINUATION_MASK = 0xC0; + static constexpr unsigned char CONTINUATION_VAL = 0x80; + + // Masks to identify leading bytes and their expected lengths + static constexpr unsigned char LEAD2_MASK = 0xE0; // 110xxxxx + static constexpr unsigned char LEAD2_VAL = 0xC0; + static constexpr unsigned char LEAD3_MASK = 0xF0; // 1110xxxx + static constexpr unsigned char LEAD3_VAL = 0xE0; + static constexpr unsigned char LEAD4_MASK = 0xF8; // 11110xxx + static constexpr unsigned char LEAD4_VAL = 0xF0; + + struct TokenParser { + public: + explicit TokenParser() { + token_buffer.reserve(32); + } + + std::optional parse(std::string_view text) { + if (token_buffer.empty()) { + if (count_incomplete_tail_bytes(text) == 0) { + return utf8_to_utf16(text); + } + } + + token_buffer.append(text); + + auto incomplete = count_incomplete_tail_bytes(token_buffer); + auto complete_len = token_buffer.size() - incomplete; + + if (complete_len > 0) { + auto result = utf8_to_utf16(std::string_view(token_buffer).substr(0, complete_len)); + token_buffer.erase(0, complete_len); + return result; + } + + return std::nullopt; + } + + void reset() { + token_buffer.clear(); + } + + private: + std::string token_buffer; + + static size_t count_incomplete_tail_bytes(std::string_view token) { + if (token.empty()) return 0; + + if ((token.back() & CONTINUATION_VAL) == 0) { + return 0; + } + + auto size = token.size(); + for (size_t i = 1; i <= 4 && i <= size; ++i) { + auto c = static_cast(token[size - i]); + + // If we found the Lead Byte (not 10xxxxxx) + if ((c & CONTINUATION_MASK) != CONTINUATION_VAL) { + size_t expected = 1; + if ((c & LEAD2_MASK) == LEAD2_VAL) expected = 2; + else if ((c & LEAD3_MASK) == LEAD3_VAL) expected = 3; + else if ((c & LEAD4_MASK) == LEAD4_VAL) expected = 4; + + return (i < expected) ? i : 0; + } + } + return 0; + } + + static std::u16string utf8_to_utf16(std::string_view utf8) { + if (utf8.empty()) return {}; + + std::u16string utf16; + // Optimization 1: Pre-allocate memory. + // UTF-16 will never have more code units than UTF-8 has bytes. + utf16.reserve(utf8.size()); + + auto cur = reinterpret_cast(utf8.data()); + auto end = cur + utf8.size(); + + while (cur < end) { + auto c = *cur; + + // Optimization 2: The ASCII Fast-Path + // Most LLM tokens are simple characters. This branch is highly predictable. + if (c < CONTINUATION_VAL) { + utf16.push_back(static_cast(c)); + cur++; + continue; + } + + // Multi-byte sequences + uint32_t cp = 0; + int len = 0; + + if ((c & LEAD2_MASK) == LEAD2_VAL) { + cp = c & 0x1F; + len = 2; + } else if ((c & LEAD3_MASK) == LEAD3_VAL) { + cp = c & 0x0F; + len = 3; + } else if ((c & LEAD4_MASK) == LEAD4_VAL) { + cp = c & 0x07; + len = 4; + } else { + // Invalid UTF-8 lead byte: use replacement character + utf16.push_back(u'\uFFFD'); + cur++; + continue; + } + + if (cur + len > end) { + utf16.push_back(u'\uFFFD'); + break; + } + + // Unroll the continuation byte checks for speed + bool valid = true; + for (int i = 1; i < len; ++i) { + if ((cur[i] & CONTINUATION_MASK) != CONTINUATION_VAL) { + valid = false; + break; + } + cp = (cp << 6) | (cur[i] & 0x3F); + } + + if (!valid) { + utf16.push_back(u'\uFFFD'); + cur++; + continue; + } + + // Optimization 3: Handle Surrogate Pairs for 4-byte UTF-8 + if (cp <= 0xFFFF) { + utf16.push_back(static_cast(cp)); + } else { + // Codepoint > 0xFFFF requires two 16-bit units (Surrogate Pair) + cp -= 0x10000; + utf16.push_back(static_cast(0xD800 + (cp >> 10))); + utf16.push_back(static_cast(0xDC00 + (cp & 0x3FF))); + } + cur += len; + } + + return utf16; + } + }; +} diff --git a/sdk/src/main/cpp/result/codes.hpp b/sdk/src/main/cpp/result/codes.hpp new file mode 100644 index 0000000..fe16299 --- /dev/null +++ b/sdk/src/main/cpp/result/codes.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +enum class ResultCode : int { + OK = 0, + + // ── Engine ────────────────────────────────────────────────────────────── + MODEL_NOT_FOUND = 1, // model file path does not exist + MODEL_LOAD_FAILED = 2, // file exists but llama_model_load_from_file returned null + BACKEND_LOAD_FAILED = 3, // ggml_backend_load returned non-zero + CANCELLED = 4, // operation was explicitly aborted via abort() + + // ── Session ───────────────────────────────────────────────────────────── + CONTEXT_INIT_FAILED = 10, // llama_init_from_model returned null + CONTEXT_OVERFLOW = 11, // HALT strategy: context is full, cannot recover + DECODE_FAILED = 12, // llama_decode returned non-zero + + // ── Catch-all ──────────────────────────────────────────────────────────── + UNKNOWN = 99, +}; diff --git a/sdk/src/main/cpp/session.cpp b/sdk/src/main/cpp/session.cpp deleted file mode 100644 index e6a442e..0000000 --- a/sdk/src/main/cpp/session.cpp +++ /dev/null @@ -1,318 +0,0 @@ -#include "session.h" - -#include "utils/llama_utils.h" -#include "utils/utf8_utils.h" -#include "utils/llama_exception.h" -#include "utils/log.h" - -#include - -#include "llama.h" - -namespace constants { - constexpr int STRATEGY_ID_HALT = 0; - constexpr int STRATEGY_ID_CLEAR_HISTORY = 1; - constexpr int STRATEGY_ID_ROLLING_WINDOW = 2; -} - -LlamaSession::LlamaSession(llama_model *model, int threads, const NativeSessionParams &config) { - auto params = llama_context_default_params(); - params.n_ctx = config.context_size; - - // Clamp to training context to avoid OOM and RoPE degradation on mobile. - auto n_ctx_train = static_cast(llama_model_n_ctx_train(model)); - if (params.n_ctx > n_ctx_train) { - LOGW("Requested context size %u exceeds model training context %u; clamping.", - params.n_ctx, n_ctx_train); - params.n_ctx = n_ctx_train; - } - - params.n_threads = threads; - params.n_threads_batch = threads; - params.n_batch = config.batch_size; - params.n_ubatch = config.micro_batch_size; - params.n_seq_max = 1; - params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; - params.type_k = GGML_TYPE_Q8_0; // Saves 50% space with very little loss - params.type_v = GGML_TYPE_Q8_0; - - auto ctx = llama_init_from_model(model, params); - - if (!ctx) { - throw LlamaException(LlamaErrorCode::CONTEXT_INIT_FAILED); - } - - auto sampler_chain = llama_sampler_chain_init(llama_sampler_chain_default_params()); - - // Penalties first — modify logits before any truncation so filters - // operate on already-penalised probabilities (matches llama.cpp canonical order). - llama_sampler_chain_add(sampler_chain, - llama_sampler_init_penalties( - static_cast(config.context_size / 2), - config.rep_pen, - 0.0f, - config.presence_pen - ) - ); - - // Optional truncation samplers - if (config.top_k_enabled) { - llama_sampler_chain_add(sampler_chain, llama_sampler_init_top_k(config.top_k)); - } - - if (config.top_p_enabled) { - llama_sampler_chain_add(sampler_chain, llama_sampler_init_top_p(config.top_p, 1)); - } - - if (config.min_p_enabled) { - llama_sampler_chain_add(sampler_chain, llama_sampler_init_min_p(config.min_p, 1)); - } - - // Temperature and final selection - if (config.temp == 0.0f) { - llama_sampler_chain_add(sampler_chain, llama_sampler_init_greedy()); - } else { - llama_sampler_chain_add(sampler_chain, llama_sampler_init_temp(config.temp)); - llama_sampler_chain_add(sampler_chain, llama_sampler_init_dist(config.seed)); - } - - llama_context.reset(ctx); - llama_sampler_chain.reset(sampler_chain); - llama_batch = llama_batch_init(static_cast(llama_n_batch(ctx)), 0, 1); - - switch (config.overflow_strategy_id) { - case constants::STRATEGY_ID_HALT: - overflow_strategy = HALT; - break; - case constants::STRATEGY_ID_CLEAR_HISTORY: - overflow_strategy = CLEAR_HISTORY; - break; - case constants::STRATEGY_ID_ROLLING_WINDOW: - default: { - auto memory = llama_get_memory(ctx); - if (llama_memory_can_shift(memory)) { - overflow_strategy = ROLLING_WINDOW; - n_drop = config.overflow_drop_tokens; - } else { - // Model uses M-RoPE / I-MRoPE (e.g. Qwen3.5) — KV-cache position - // shifting is not supported, so fall back to clearing history instead. - overflow_strategy = CLEAR_HISTORY; - } - break; - } - } -} - -LlamaSession::~LlamaSession() { - llama_batch_free(llama_batch); -} - -bool LlamaSession::roll_kv_cache_if_needed(uint32_t required_tokens) { - auto ctx = llama_context.get(); - auto n_ctx = llama_n_ctx(ctx); - - if (n_past + required_tokens <= n_ctx) { - return true; - } - - switch (overflow_strategy) { - case HALT: { - return false; - } - - case CLEAR_HISTORY: { - roll_kv_cache_till_system_prompt(); - return true; - } - - case ROLLING_WINDOW: - default: { - return roll_kv_cache_to_accommodate(required_tokens); - } - } -} - -void LlamaSession::clear_kv_cache(int32_t start_pos, int32_t end_pos) { - auto ctx = llama_context.get(); - auto memory = llama_get_memory(ctx); - - llama_memory_seq_rm(memory, 0, start_pos, end_pos); -} - -void LlamaSession::roll_kv_cache_till_system_prompt() { - clear_kv_cache(n_keep, -1); - n_past = n_keep; -} - -bool LlamaSession::roll_kv_cache_to_accommodate(uint32_t required_tokens) { - auto ctx = llama_context.get(); - auto n_ctx = llama_n_ctx(ctx); - auto memory = llama_get_memory(ctx); - - while (n_past + required_tokens > n_ctx) { - auto safe_drop = std::min(n_drop, n_past - n_keep); - if (safe_drop <= 0) { - return false; // Cannot drop enough tokens without eating system prompt - } - - // Remove old tokens - clear_kv_cache(n_keep, n_keep + safe_drop); - // Shift remaining tokens left by 'safe_drop' amount - llama_memory_seq_add(memory, 0, n_keep + safe_drop, -1, -safe_drop); - - n_past -= safe_drop; - } - return true; -} - -void LlamaSession::ingest_prompt(const std::string &text, bool is_system_prompt, bool add_special) { - LOGI("ingest_prompt:\n%s", text.c_str()); - is_aborted.store(false); - - auto ctx = llama_context.get(); - auto model = llama_get_model(ctx); - auto vocab = llama_model_get_vocab(model); - - auto tokens = utils::tokenize(vocab, text, add_special, true); - if (tokens.empty()) return; - - uint32_t n_ctx = llama_n_ctx(ctx); - uint32_t n_batch_limit = llama_n_batch(ctx); - uint32_t max_usable_tokens = is_system_prompt ? n_ctx : (n_ctx - n_keep); - - if (tokens.size() > max_usable_tokens) { - tokens.erase(tokens.begin(), tokens.end() - max_usable_tokens); - } - - for (size_t i = 0; i < tokens.size(); i += n_batch_limit) { - if (is_aborted.load()) { - throw LlamaException(LlamaErrorCode::CANCELLED); - } - - auto chunk_size = std::min(n_batch_limit, static_cast(tokens.size() - i)); - - if (!roll_kv_cache_if_needed(chunk_size)) { - throw LlamaException(LlamaErrorCode::CONTEXT_OVERFLOW); - } - - utils::batch_clear(llama_batch); - for (int32_t j = 0; j < chunk_size; j++) { - utils::batch_add( - llama_batch, - tokens[i + j], - n_past + j, - {0}, - i + j == tokens.size() - 1 - ); - } - - auto decode_result = llama_decode(ctx, llama_batch); - if (decode_result != 0) { - throw LlamaException(LlamaErrorCode::DECODE_FAILED, - std::to_string(decode_result)); - } - - n_past += static_cast(chunk_size); - } - - if (is_system_prompt) { - n_keep = n_past; - } -} - -bool LlamaSession::is_token_buffer_valid() { - return !token_buffer.empty() && utils::llm_is_valid_utf8(token_buffer); -} - -std::u16string LlamaSession::get_and_clear_token_buffer() { - auto result = utils::llm_utf8_to_utf16_sanitized(token_buffer); - token_buffer.clear(); - return result; -} - -void LlamaSession::setSystemPrompt(const std::string &prompt, bool add_special) { - clear_kv_cache(0, -1); - - n_past = 0; - n_keep = 0; - token_buffer.clear(); - - ingest_prompt(prompt, true, add_special); -} - -void LlamaSession::ingestPrompt(const std::string &prompt, bool add_special) { - ingest_prompt(prompt, false, add_special); -} - -Generation LlamaSession::generate() { - auto ctx = llama_context.get(); - auto model = llama_get_model(ctx); - auto vocab = llama_model_get_vocab(model); - auto sampler = llama_sampler_chain.get(); - - is_aborted.store(false); - - while (true) { - if (is_aborted.load()) { - throw LlamaException(LlamaErrorCode::CANCELLED); - } - - auto new_token = llama_sampler_sample(sampler, ctx, -1); - - if (llama_vocab_is_eog(vocab, new_token)) { - // Decode the EOG token into the KV cache so the turn delimiter - // is present for subsequent turns (logits=false — no sampling needed). - if (roll_kv_cache_if_needed(1)) { - utils::batch_clear(llama_batch); - utils::batch_add(llama_batch, new_token, n_past, {0}, false); - if (llama_decode(ctx, llama_batch) == 0) { - n_past += 1; - } - } - - return Generation{ - .token = is_token_buffer_valid() - ? std::make_optional(get_and_clear_token_buffer()) - : std::nullopt, - .is_complete = true, - }; - } - - auto piece = utils::token_to_piece(vocab, new_token, true); - token_buffer.append(piece); - - if (!roll_kv_cache_if_needed(1)) { - token_buffer.clear(); - throw LlamaException(LlamaErrorCode::CONTEXT_OVERFLOW); - } - - utils::batch_clear(llama_batch); - utils::batch_add(llama_batch, new_token, n_past, {0}, true); - - auto decode_result = llama_decode(ctx, llama_batch); - if (decode_result != 0) { - throw LlamaException(LlamaErrorCode::DECODE_FAILED, - std::to_string(decode_result)); - } - - n_past += 1; - - if (is_token_buffer_valid()) { - return Generation{ - .token = is_token_buffer_valid() - ? std::make_optional(get_and_clear_token_buffer()) - : std::nullopt, - .is_complete = false, - }; - } - } -} - -void LlamaSession::clear() { - roll_kv_cache_till_system_prompt(); - token_buffer.clear(); -} - -void LlamaSession::abort() { - is_aborted.store(true); -} diff --git a/sdk/src/main/cpp/session.h b/sdk/src/main/cpp/session.h deleted file mode 100644 index 0084a07..0000000 --- a/sdk/src/main/cpp/session.h +++ /dev/null @@ -1,94 +0,0 @@ -#pragma once - -#include "llama-cpp.h" - -#include -#include - -enum OverflowStrategy { - HALT, - CLEAR_HISTORY, - ROLLING_WINDOW, -}; - -struct NativeSessionParams { - int context_size; - int overflow_strategy_id; - int overflow_drop_tokens; - - bool top_k_enabled; - int top_k; - bool top_p_enabled; - float top_p; - bool min_p_enabled; - float min_p; - - float rep_pen; - float presence_pen; - float temp; - - int seed; - - int batch_size; - int micro_batch_size; -}; - -struct Generation { - std::optional token; - bool is_complete; -}; - -#include - -class LlamaSession { -private: - llama_context_ptr llama_context; - llama_sampler_ptr llama_sampler_chain; - std::atomic is_aborted{false}; - - // Core Memory State - llama_batch llama_batch{0}; - std::string token_buffer; - - int32_t n_past = 0; - int32_t n_keep = 0; - OverflowStrategy overflow_strategy = ROLLING_WINDOW; - int32_t n_drop = 500; - - bool roll_kv_cache_if_needed(uint32_t required_tokens); - - void clear_kv_cache(int32_t start_pos, int32_t end_pos); - - void roll_kv_cache_till_system_prompt(); - - bool roll_kv_cache_to_accommodate(uint32_t required_tokens); - - void ingest_prompt(const std::string &text, bool is_system_prompt, bool add_special); - - bool is_token_buffer_valid(); - - std::u16string get_and_clear_token_buffer(); - -public: - LlamaSession(llama_model *model, int threads, const NativeSessionParams &config); - - ~LlamaSession(); - - LlamaSession(const LlamaSession &) = delete; - - LlamaSession(LlamaSession &&) = delete; - - LlamaSession &operator=(const LlamaSession &) = delete; - - LlamaSession &operator=(LlamaSession &&) = delete; - - void setSystemPrompt(const std::string &prompt, bool add_special); - - void ingestPrompt(const std::string &prompt, bool add_special); - - Generation generate(); - - void clear(); - - void abort(); -}; diff --git a/sdk/src/main/cpp/utils/llama_utils.h b/sdk/src/main/cpp/session/batch.hpp similarity index 68% rename from sdk/src/main/cpp/utils/llama_utils.h rename to sdk/src/main/cpp/session/batch.hpp index fd107ed..83d8c5b 100644 --- a/sdk/src/main/cpp/utils/llama_utils.h +++ b/sdk/src/main/cpp/session/batch.hpp @@ -4,28 +4,29 @@ #include #include "llama.h" -namespace utils { - // Taken from common.h/cpp +namespace session { + // Taken from common static void batch_clear(struct llama_batch &batch) { batch.n_tokens = 0; } - static bool batch_add( - struct llama_batch &batch, - llama_token id, - llama_pos pos, - const std::vector &seq_ids, - bool output) { + static bool batch_add(struct llama_batch &batch, + llama_token id, + llama_pos pos, + bool output) { if (!batch.seq_id[batch.n_tokens]) { return false; } batch.token[batch.n_tokens] = id; batch.pos[batch.n_tokens] = pos; - batch.n_seq_id[batch.n_tokens] = static_cast(seq_ids.size()); - for (size_t i = 0; i < seq_ids.size(); ++i) { - batch.seq_id[batch.n_tokens][i] = seq_ids[i]; - } +// Since we aren't batching yet +// batch.n_seq_id[batch.n_tokens] = static_cast(seq_ids.size()); +// for (size_t i = 0; i < seq_ids.size(); ++i) { +// batch.seq_id[batch.n_tokens][i] = seq_ids[i]; +// } + batch.n_seq_id[batch.n_tokens] = 1; + batch.seq_id[batch.n_tokens][0] = 0; batch.logits[batch.n_tokens] = output ? 1 : 0; batch.n_tokens++; @@ -33,11 +34,10 @@ namespace utils { return true; } - static std::vector tokenize( - const struct llama_vocab *vocab, - const std::string &text, - bool add_special, - bool parse_special) { + static std::vector tokenize(const struct llama_vocab *vocab, + std::string_view text, + bool add_special, + bool parse_special) { // upper limit for the number of tokens auto n_tokens = text.length() + 2 * add_special; std::vector result(n_tokens); @@ -62,8 +62,8 @@ namespace utils { return result; } - static std::string - token_to_piece(const struct llama_vocab *vocab, llama_token token, bool special) { + static std::string token_to_piece(const struct llama_vocab *vocab, + llama_token token, bool special) { std::string piece; piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' const int n_chars = llama_token_to_piece(vocab, token, &piece[0], diff --git a/sdk/src/main/cpp/session/session.cpp b/sdk/src/main/cpp/session/session.cpp new file mode 100644 index 0000000..81f5ab4 --- /dev/null +++ b/sdk/src/main/cpp/session/session.cpp @@ -0,0 +1,334 @@ +#include "session.hpp" + +#include "batch.hpp" +#include "result/codes.hpp" +#include "utils/log.hpp" + +#include +#include + +#include "llama.h" + +namespace session { + namespace constants { + constexpr int STRATEGY_ID_HALT = 0; + constexpr int STRATEGY_ID_CLEAR_HISTORY = 1; + constexpr int STRATEGY_ID_ROLLING_WINDOW = 2; + } + + static llama_sampler *create_sampler(const NativeSessionParams &config) { + auto sampler_chain = llama_sampler_chain_init(llama_sampler_chain_default_params()); + + // Penalties first — modify logits before any truncation so filters + // operate on already-penalised probabilities (matches llama.cpp canonical order). + llama_sampler_chain_add(sampler_chain, + llama_sampler_init_penalties( + static_cast(config.context_size / 2), + config.rep_pen, + 0.0f, + config.presence_pen + ) + ); + + // Optional truncation samplers + if (config.top_k_enabled) { + llama_sampler_chain_add(sampler_chain, + llama_sampler_init_top_k(config.top_k)); + } + + if (config.top_p_enabled) { + llama_sampler_chain_add(sampler_chain, + llama_sampler_init_top_p(config.top_p, 1)); + } + + if (config.min_p_enabled) { + llama_sampler_chain_add(sampler_chain, + llama_sampler_init_min_p(config.min_p, 1)); + } + + // Temperature and final selection + if (config.temp == 0.0f) { + llama_sampler_chain_add(sampler_chain, + llama_sampler_init_greedy()); + } else { + llama_sampler_chain_add(sampler_chain, + llama_sampler_init_temp(config.temp)); + llama_sampler_chain_add(sampler_chain, + llama_sampler_init_dist(config.seed)); + } + + return sampler_chain; + } + + static llama_context_params get_context_init_params(const llama_model *model, + const NativeSessionParams &config) { + auto params = llama_context_default_params(); + params.n_ctx = config.context_size; + + // Clamp to training context to avoid OOM and RoPE degradation on mobile. + auto n_ctx_train = static_cast(llama_model_n_ctx_train(model)); + if (params.n_ctx > n_ctx_train) { + LOGW("Requested context size %u exceeds model training context %u; clamping.", + params.n_ctx, n_ctx_train); + params.n_ctx = n_ctx_train; + } + + params.n_threads = config.threads; + params.n_threads_batch = config.threads; + params.n_batch = config.batch_size; + params.n_ubatch = config.micro_batch_size; + params.n_seq_max = 1; + params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; + params.type_k = GGML_TYPE_Q8_0; // Saves 50% space with very little loss + params.type_v = GGML_TYPE_Q8_0; + + return params; + } + + static OverflowStrategy decide_overflow_strategy(const llama_context *ctx, + const NativeSessionParams &config) { + switch (config.overflow_strategy_id) { + case constants::STRATEGY_ID_HALT: + return HALT; + case constants::STRATEGY_ID_CLEAR_HISTORY: + return CLEAR_HISTORY; + case constants::STRATEGY_ID_ROLLING_WINDOW: + default: + if (llama_memory_can_shift(llama_get_memory(ctx))) { + return ROLLING_WINDOW; + } else { + return CLEAR_HISTORY; + } + } + } + + Session::Session(llama_model *model, + const NativeSessionParams &config) { + auto ctx = llama_init_from_model(model, get_context_init_params(model, config)); + + if (!ctx) { + throw std::runtime_error("Failed to initialize llama context."); + } + + auto system_info = llama_print_system_info(); + LOGI("Initialized llama context with system info:\n%s", system_info); + + llama_context.reset(ctx); + llama_sampler_chain.reset(create_sampler(config)); + llama_batch = llama_batch_init(static_cast(llama_n_batch(ctx)), 0, 1); + + overflow_strategy = decide_overflow_strategy(ctx, config); + n_drop = config.overflow_drop_tokens; + } + + Session::~Session() { + llama_batch_free(llama_batch); + } + + ResultCode Session::set_system_prompt(std::string_view prompt) { + return ingest_prompt(prompt, true); + } + + ResultCode Session::add_user_prompt(std::string_view prompt) { + return ingest_prompt(prompt, false); + } + + Generation Session::generate() { + auto ctx = llama_context.get(); + auto model = llama_get_model(ctx); + auto vocab = llama_model_get_vocab(model); + auto sampler = llama_sampler_chain.get(); + auto result = ResultCode::OK; + + is_aborted.store(false); + + while (true) { + if (is_aborted.load()) { + result = ResultCode::CANCELLED; + break; + } + + auto new_token = llama_sampler_sample(sampler, ctx, -1); + auto is_end_token = llama_vocab_is_eog(vocab, new_token); + + if (!roll_kv_cache_if_needed(1)) { + result = ResultCode::CONTEXT_OVERFLOW; + break; + } + + batch_clear(llama_batch); + batch_add(llama_batch, new_token, n_past, !is_end_token); + + result = decode_current_batch(); + if (result == ResultCode::OK) { + n_past += 1; + } else { + break; + } + + if (is_end_token) { + token_parser.reset(); + break; + } + + auto piece = token_to_piece(vocab, new_token, true); + auto token = token_parser.parse(piece); + + if (token.has_value()) { + return { + .token = token, + .result_code = ResultCode::OK, + .is_complete = false, + }; + } + } + + return { + .token = std::nullopt, + .result_code = result, + .is_complete = true, + }; + } + + void Session::clear() { + roll_kv_cache_till_system_prompt(); + } + + void Session::abort() { + is_aborted.store(true); + } + + + // -------- private --------------- + + ResultCode Session::ingest_prompt(std::string_view text, bool reset_sequence) { + is_aborted.store(false); + + if (reset_sequence) { + clear_kv_cache(0, -1); + n_keep = 0; + } + + auto ctx = llama_context.get(); + auto model = llama_get_model(ctx); + auto vocab = llama_model_get_vocab(model); + auto n_ctx = llama_n_ctx(ctx); + auto n_batch_limit = llama_n_batch(ctx); + auto should_add_bos = llama_vocab_get_add_bos(vocab); + + auto tokens = tokenize(vocab, text, should_add_bos, true); + + if (tokens.empty()) { + return ResultCode::OK; + } + + if (tokens.size() > n_ctx) { + return ResultCode::CONTEXT_OVERFLOW; + } + + for (size_t i = 0; i < tokens.size(); i += n_batch_limit) { + if (is_aborted.load()) { + return ResultCode::CANCELLED; + } + + auto chunk_size = std::min(n_batch_limit, static_cast(tokens.size() - i)); + if (!roll_kv_cache_if_needed(chunk_size)) { + return ResultCode::CONTEXT_OVERFLOW; + } + + batch_clear(llama_batch); + for (uint32_t j = 0; j < chunk_size; j++) { + auto token_pos = static_cast(n_past + j); + auto is_last_token = i + j == tokens.size() - 1; + batch_add(llama_batch, tokens[i + j], token_pos, is_last_token); + } + + auto decode_result = decode_current_batch(); + if (decode_result != ResultCode::OK) { + return decode_result; + } + + n_past += static_cast(chunk_size); + } + + if (reset_sequence) { + n_keep = n_past; + } + + return ResultCode::OK; + } + + ResultCode Session::decode_current_batch() { + auto ctx = llama_context.get(); + auto decode_result = llama_decode(ctx, llama_batch); + + if (decode_result == 0) { + return ResultCode::OK; + } + + if (decode_result == 1) { + return ResultCode::CONTEXT_OVERFLOW; + } + + if (decode_result == 2 || decode_result < -1) { + // Partial batches might remain in memory, need to clear till last point + clear_kv_cache(n_past, -1); + } + + return ResultCode::DECODE_FAILED; + } + + void Session::clear_kv_cache(int32_t start_pos, int32_t end_pos) { + auto ctx = llama_context.get(); + auto memory = llama_get_memory(ctx); + + llama_memory_seq_rm(memory, 0, start_pos, end_pos); + } + + bool Session::roll_kv_cache_if_needed(uint32_t required_tokens) { + auto ctx = llama_context.get(); + auto n_ctx = llama_n_ctx(ctx); + + if (n_past + required_tokens <= n_ctx) { + return true; + } + + switch (overflow_strategy) { + case HALT: + return false; + + case CLEAR_HISTORY: + roll_kv_cache_till_system_prompt(); + return true; + + case ROLLING_WINDOW: + default: + return roll_kv_cache_to_accommodate(required_tokens); + } + } + + void Session::roll_kv_cache_till_system_prompt() { + clear_kv_cache(n_keep, -1); + n_past = n_keep; + } + + bool Session::roll_kv_cache_to_accommodate(uint32_t required_tokens) { + auto ctx = llama_context.get(); + auto n_ctx = llama_n_ctx(ctx); + auto memory = llama_get_memory(ctx); + + while (n_past + required_tokens > n_ctx) { + auto safe_drop = std::min(n_drop, n_past - n_keep); + if (safe_drop <= 0) { + return false; // Cannot drop enough tokens without eating system prompt + } + + clear_kv_cache(n_keep, n_keep + safe_drop); + llama_memory_seq_add(memory, 0, n_keep + safe_drop, -1, -safe_drop); + + n_past -= safe_drop; + } + + return true; + } +} diff --git a/sdk/src/main/cpp/session/session.hpp b/sdk/src/main/cpp/session/session.hpp new file mode 100644 index 0000000..64f9d32 --- /dev/null +++ b/sdk/src/main/cpp/session/session.hpp @@ -0,0 +1,99 @@ +#pragma once + +#include "llama-cpp.h" + +#include "parsers/tag.hpp" +#include "parsers/token.hpp" +#include "result/codes.hpp" + +#include +#include +#include + +namespace session { + enum OverflowStrategy { + HALT, + CLEAR_HISTORY, + ROLLING_WINDOW, + }; + + struct NativeSessionParams { + int context_size; + int threads; + int overflow_strategy_id; + int overflow_drop_tokens; + + bool top_k_enabled; + int top_k; + bool top_p_enabled; + float top_p; + bool min_p_enabled; + float min_p; + + float rep_pen; + float presence_pen; + float temp; + + int seed; + + int batch_size; + int micro_batch_size; + }; + + struct Generation { + std::optional token; + ResultCode result_code; + bool is_complete; + }; + + class Session { + public: + Session(llama_model *model, const NativeSessionParams &config); + + ~Session(); + + Session(const Session &) = delete; + + Session(Session &&) = delete; + + Session &operator=(const Session &) = delete; + + Session &operator=(Session &&) = delete; + + ResultCode add_user_prompt(std::string_view prompt); + + ResultCode set_system_prompt(std::string_view prompt); + + Generation generate(); + + void clear(); + + void abort(); + + private: + llama_context_ptr llama_context; + llama_sampler_ptr llama_sampler_chain; + std::atomic is_aborted{false}; + + // Core Memory State + llama_batch llama_batch{0}; + parsers::TokenParser token_parser; + + int32_t n_past = 0; + int32_t n_keep = 0; + int32_t n_drop = 500; + OverflowStrategy overflow_strategy = ROLLING_WINDOW; + + ResultCode ingest_prompt(std::string_view prompt, bool reset_sequence); + + ResultCode decode_current_batch(); + + void clear_kv_cache(int32_t start_pos, int32_t end_pos); + + bool roll_kv_cache_if_needed(uint32_t required_tokens); + + void roll_kv_cache_till_system_prompt(); + + bool roll_kv_cache_to_accommodate(uint32_t required_tokens); + }; +} diff --git a/sdk/src/main/cpp/utils/error_codes.h b/sdk/src/main/cpp/utils/error_codes.h deleted file mode 100644 index 0f6418b..0000000 --- a/sdk/src/main/cpp/utils/error_codes.h +++ /dev/null @@ -1,25 +0,0 @@ -#pragma once - -/** - * Canonical error codes for all failure modes in the llama-bro native layer. - * These values are stringified and passed through JNI as Java RuntimeExceptions, - * where the Kotlin internal layer maps them to typed LlamaError subclasses. - * - * IMPORTANT: The integer values of these enumerators are part of the ABI between - * C++ and Kotlin. Do NOT reorder or renumber — only append new entries. - */ -enum class LlamaErrorCode : int { - // ── Engine ────────────────────────────────────────────────────────────── - MODEL_NOT_FOUND = 1, // model file path does not exist - MODEL_LOAD_FAILED = 2, // file exists but llama_model_load_from_file returned null - BACKEND_LOAD_FAILED = 3, // ggml_backend_load returned non-zero - CANCELLED = 4, // operation was explicitly aborted via abort() - - // ── Session ───────────────────────────────────────────────────────────── - CONTEXT_INIT_FAILED = 10, // llama_init_from_model returned null - CONTEXT_OVERFLOW = 11, // HALT strategy: context is full, cannot recover - DECODE_FAILED = 12, // llama_decode returned non-zero - - // ── Catch-all ──────────────────────────────────────────────────────────── - UNKNOWN = 99, -}; diff --git a/sdk/src/main/cpp/utils/ggml_variant_chooser.h b/sdk/src/main/cpp/utils/ggml_variant_chooser.h deleted file mode 100644 index 54e64b8..0000000 --- a/sdk/src/main/cpp/utils/ggml_variant_chooser.h +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#include -#include - -const char *resolve_best_ggml_backend() { - unsigned long hwcap = getauxval(AT_HWCAP); - unsigned long hwcap2 = getauxval(AT_HWCAP2); - - bool has_fp16 = hwcap & HWCAP_ASIMDHP; - bool has_dotprod = hwcap & HWCAP_ASIMDDP; - bool has_sve = hwcap & HWCAP_SVE; - bool has_sve2 = hwcap2 & HWCAP2_SVE2; - bool has_i8mm = hwcap2 & HWCAP2_I8MM; - - // Cascade down from highest architectural requirements to lowest - if (has_sve && has_sve2 && has_i8mm && has_fp16 && has_dotprod) { - return "libggml-cpu-android_armv9.0_1.so"; - } - - if (has_i8mm && has_fp16 && has_dotprod) { - return "libggml-cpu-android_armv8.6_1.so"; - } - - if (has_fp16 && has_dotprod) { - return "libggml-cpu-android_armv8.2_2.so"; - } - - if (has_dotprod) { - return "libggml-cpu-android_armv8.2_1.so"; - } - - return "libggml-cpu-android_armv8.0_1.so"; -} diff --git a/sdk/src/main/cpp/utils/jni_error_thrower.h b/sdk/src/main/cpp/utils/jni_error_thrower.h deleted file mode 100644 index 6406094..0000000 --- a/sdk/src/main/cpp/utils/jni_error_thrower.h +++ /dev/null @@ -1,32 +0,0 @@ -#pragma once - -#include -#include -#include "error_codes.h" -#include "llama_exception.h" - -/** - * The single, exclusive point where Java/JNI exceptions are thrown from native code. - * - * Message format sent to Kotlin: ":" - * The Kotlin NativeErrorMapper parses this into a typed LlamaError subclass. - * - * ── Rules ──────────────────────────────────────────────────────────────────── - * 1. Native code (engine.cpp, session.cpp) MUST throw LlamaException — never call this directly. - * 2. JNI bridge code MUST catch LlamaException and call throwLlamaError(env, ex) — never call - * env->ThrowNew() directly. - * 3. Always return a zero/null sentinel from the JNI function immediately after calling this. - */ -inline void throwLlamaError(JNIEnv *env, LlamaErrorCode code, const char *detail = "") { - std::string message = std::to_string(static_cast(code)) + ":" + detail; - jclass clazz = env->FindClass("java/lang/RuntimeException"); - if (clazz != nullptr) { - env->ThrowNew(clazz, message.c_str()); - env->DeleteLocalRef(clazz); - } -} - -/** Convenience overload: converts a LlamaException directly. One-liner at every catch site. */ -inline void throwLlamaError(JNIEnv *env, const LlamaException &ex) { - throwLlamaError(env, ex.code, ex.what()); -} diff --git a/sdk/src/main/cpp/utils/llama_exception.h b/sdk/src/main/cpp/utils/llama_exception.h deleted file mode 100644 index e9d150f..0000000 --- a/sdk/src/main/cpp/utils/llama_exception.h +++ /dev/null @@ -1,25 +0,0 @@ -#pragma once - -#include -#include -#include "error_codes.h" - -/** - * The canonical exception type thrown by all llama-bro native code. - * - * Carries a typed [LlamaErrorCode] alongside a human-readable detail string. - * The JNI layer catches this and converts it to a Java exception via throwLlamaError(). - * - * Usage in native code: - * throw LlamaException(LlamaErrorCode::MODEL_LOAD_FAILED, config.model_path); - * - * Never throw std::runtime_error or any other exception type from native code. - * Never call env->ThrowNew() from native code — that is the JNI layer's job. - */ -class LlamaException : public std::runtime_error { -public: - const LlamaErrorCode code; - - LlamaException(LlamaErrorCode code, const std::string &detail = "") - : std::runtime_error(detail), code(code) {} -}; diff --git a/sdk/src/main/cpp/utils/log.h b/sdk/src/main/cpp/utils/log.hpp similarity index 79% rename from sdk/src/main/cpp/utils/log.h rename to sdk/src/main/cpp/utils/log.hpp index 7bb0b5f..f28896a 100644 --- a/sdk/src/main/cpp/utils/log.h +++ b/sdk/src/main/cpp/utils/log.hpp @@ -7,4 +7,5 @@ #define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) #define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__) #define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__) +#define LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__) diff --git a/sdk/src/main/cpp/utils/overloaded.hpp b/sdk/src/main/cpp/utils/overloaded.hpp new file mode 100644 index 0000000..3c8df34 --- /dev/null +++ b/sdk/src/main/cpp/utils/overloaded.hpp @@ -0,0 +1,4 @@ +#pragma once + +template struct overloaded : Ts... { using Ts::operator()...; }; +template overloaded(Ts...) -> overloaded; diff --git a/sdk/src/main/cpp/utils/utf8_utils.h b/sdk/src/main/cpp/utils/utf8_utils.h deleted file mode 100644 index d3bc5d7..0000000 --- a/sdk/src/main/cpp/utils/utf8_utils.h +++ /dev/null @@ -1,95 +0,0 @@ -#pragma once - -#include - -namespace utils { - static bool llm_is_valid_utf8(const std::string &str) { - size_t i = 0; - while (i < str.length()) { - auto c = static_cast(str[i]); - int len = 0; - if (c < 0x80) { - len = 1; - } else if ((c & 0xE0) == 0xC0) { - len = 2; - } else if ((c & 0xF0) == 0xE0) { - len = 3; - } else if ((c & 0xF8) == 0xF0) { - len = 4; - } else { - return false; - } - - if (i + len > str.length()) { - return false; - } - for (int j = 1; j < len; ++j) { - if ((static_cast(str[i + j]) & 0xC0) != 0x80) { - return false; - } - } - i += len; - } - return true; - } - - static std::u16string llm_utf8_to_utf16_sanitized(const std::string &utf8) { - std::u16string out; - size_t i = 0; - - while (i < utf8.size()) { - uint32_t codepoint = 0xFFFD; - auto c = static_cast(utf8[i]); - size_t remaining = utf8.size() - i; - - if (c < 0x80) { - codepoint = c; - i += 1; - } else if ((c & 0xE0) == 0xC0 && remaining >= 2) { - auto c1 = static_cast(utf8[i + 1]); - if ((c1 & 0xC0) == 0x80) { - codepoint = ((c & 0x1F) << 6) | (c1 & 0x3F); - if (codepoint < 0x80) { codepoint = 0xFFFD; } - i += 2; - } else { i += 1; } - } else if ((c & 0xF0) == 0xE0 && remaining >= 3) { - auto c1 = static_cast(utf8[i + 1]); - auto c2 = static_cast(utf8[i + 2]); - if ((c1 & 0xC0) == 0x80 && (c2 & 0xC0) == 0x80) { - codepoint = ((c & 0x0F) << 12) | ((c1 & 0x3F) << 6) | (c2 & 0x3F); - if (codepoint < 0x800 || (codepoint >= 0xD800 && codepoint <= 0xDFFF)) { - codepoint = 0xFFFD; - } - i += 3; - } else { i += 1; } - } else if ((c & 0xF8) == 0xF0 && remaining >= 4) { - auto c1 = static_cast(utf8[i + 1]); - auto c2 = static_cast(utf8[i + 2]); - auto c3 = static_cast(utf8[i + 3]); - if ((c1 & 0xC0) == 0x80 && (c2 & 0xC0) == 0x80 && (c3 & 0xC0) == 0x80) { - codepoint = ((c & 0x07) << 18) | ((c1 & 0x3F) << 12) | ((c2 & 0x3F) << 6) | (c3 & 0x3F); - if (codepoint < 0x10000 || codepoint > 0x10FFFF) { - codepoint = 0xFFFD; - } - i += 4; - } else { i += 1; } - } else { i += 1; } - - if (codepoint == 0xFFFD) { - out.push_back(u'\uFFFD'); - continue; - } - - if (codepoint <= 0xFFFF) { - out.push_back(static_cast(codepoint)); - } else { - codepoint -= 0x10000; - auto high = static_cast(0xD800 + (codepoint >> 10)); - auto low = static_cast(0xDC00 + (codepoint & 0x3FF)); - out.push_back(high); - out.push_back(low); - } - } - return out; - } -} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/LlamaChatSession.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/LlamaChatSession.kt deleted file mode 100644 index 2d63f5f..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/LlamaChatSession.kt +++ /dev/null @@ -1,72 +0,0 @@ -package com.suhel.llamabro.sdk - -import com.suhel.llamabro.sdk.model.Completion -import com.suhel.llamabro.sdk.model.Message -import kotlinx.coroutines.flow.Flow - -/** - * High-level conversational API built on top of [LlamaSession]. - * - * This interface abstracts away the low-level prompt engineering and token - * management. It handles: - * - Message formatting using the [com.suhel.llamabro.sdk.model.PromptFormat] from the engine. - * - Assistant turn boundary management. - * - Real-time parsing of "thinking" blocks (e.g., `...`). - * - Accumulating token streams into structured [Completion] snapshots. - * - * ### Usage - * 1. Initialize with a system prompt. - * 2. Call [completion] with a user message to start generation. - * 3. Collect the resulting flow to receive updates. - * - * ### Thread Safety - * **Instances are not thread-safe.** Generation must be collected from a single - * coroutine at a time. Do not call [reset] or [loadHistory] while a generation - * flow is active. - */ -interface LlamaChatSession { - val supportsThinking: Boolean - /** - * Sends a message to the model and returns a reactive [Completion] flow. - * - * The flow emits progressively accumulated snapshots. Each emission includes - * the latest thinking text, content text, and eventually performance metrics - * like tokens-per-second once generation finishes. - * - * The flow **never throws**. All error conditions including fatal native errors - * are surfaced as a terminal [Completion] with [Completion.error] set. - * - * If the collector's coroutine is cancelled, the underlying native generation - * is automatically aborted. - * - * @param prompt The user's input text. - * @param enableThinking Whether to enable "thinking" mode. - * @param maxThinkingTokens Maximum number of tokens the model may spend thinking. - * When the limit is reached the closing tag is injected into the context, - * forcing the model to begin its response. `null` means no limit. - * @return A flow of [Completion] updates. - */ - fun completion( - prompt: String, - enableThinking: Boolean = false, - maxThinkingTokens: Int? = null, - ): Flow - - /** - * Clears the current conversation history while retaining the system prompt. - * - * This is useful for starting a new topic within the same session. - */ - suspend fun reset() - - /** - * Loads a sequence of historical messages into the session. - * - * This is used to "restore" a conversation state from a database or cache. - * It performs prompt ingestion (pre-fill) for each message but does not - * trigger generation. - * - * @param messages A list of [Message.User] and [Message.Assistant] messages. - */ - suspend fun loadHistory(messages: List) -} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/LlamaSession.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/LlamaSession.kt deleted file mode 100644 index 9cb67a2..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/LlamaSession.kt +++ /dev/null @@ -1,93 +0,0 @@ -package com.suhel.llamabro.sdk - -import com.suhel.llamabro.sdk.model.ResourceState -import com.suhel.llamabro.sdk.model.ModelConfig -import com.suhel.llamabro.sdk.model.TokenGenerationResult -import kotlinx.coroutines.flow.Flow - -/** - * Low-level inference session backed by a llama.cpp context. - * - * A session provides direct access to the KV cache and token generation loop. - * It is suitable for applications requiring fine-grained control over prompt - * ingestion and token sampling. - * - * For a higher-level, conversational API that handles chat templates and - * reasoning blocks, use [createChatSession]. - * - * ### Thread Safety - * **Instances are not thread-safe.** A session must be accessed from a single - * coroutine at a time. All suspending methods are safe to call from any - * dispatcher (they internally switch to [kotlinx.coroutines.Dispatchers.IO]). - * - * ### Lifecycle - * Sessions are bound to the parent [LlamaEngine]. Always call [close] when finished - * to release the native context memory. - */ -interface LlamaSession : AutoCloseable { - /** The configuration used to load the parent engine. */ - val modelConfig: ModelConfig - - /** - * Sets the system prompt for the session. - * - * This prompt is usually pinned at the start of the context and is preserved - * even during history clearing or rolling window overflows. - * - * @param text Raw text to add to the context. - * @param addSpecial If true, prepends the model's default BOS token. - */ - suspend fun setSystemPrompt(text: String, addSpecial: Boolean = true) - - /** - * Ingests raw text into the KV cache. - * - * This method blocks until the text is fully processed (pre-filled). - * It is cancellable; if the coroutine is cancelled, the native pre-fill - * loop will be interrupted. - * - * @param prompt Raw text to add to the context. - * @param addSpecial If true, prepends the model's default BOS token. - * @throws LlamaError.ContextOverflow if the context is full and cannot be recovered. - */ - suspend fun ingestPrompt(prompt: String, addSpecial: Boolean = false) - - /** - * Samples the next token from the model based on the current context. - * - * Call this in a loop to generate a complete response. - * - * @return The generated token as a String, or `null` if the model emits an - * End-of-Generation (EOG) token. - * @throws LlamaError.DecodeFailed if the native sampling loop fails. - */ - suspend fun generate(): TokenGenerationResult - - /** - * Clears the conversation history from the KV cache. - * - * The system prompt (if set via [setSystemPrompt]) is preserved. - */ - suspend fun clear() - - /** - * Asynchronously signals the native engine to stop any active computation. - * - * Use this to immediately halt a long-running [ingestPrompt] or [generate] call - * from another thread or UI action. - */ - fun abort() - - /** - * Creates a high-level chat session on top of this low-level session. - * - * @param systemPrompt The instruction defining the assistant's persona. - * @return A ready-to-use [LlamaChatSession]. - */ - suspend fun createChatSession(systemPrompt: String): LlamaChatSession - - /** - * Creates a high-level chat session asynchronously. - */ - fun createChatSessionFlow(systemPrompt: String): Flow> -} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/ProgressListener.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/ProgressListener.kt new file mode 100644 index 0000000..f76235b --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/ProgressListener.kt @@ -0,0 +1,5 @@ +package com.suhel.llamabro.sdk + +internal interface ProgressListener { + fun onProgress(progress: Float): Boolean +} \ No newline at end of file diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/chat/LlamaChatSession.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/chat/LlamaChatSession.kt new file mode 100644 index 0000000..a0e0d20 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/chat/LlamaChatSession.kt @@ -0,0 +1,15 @@ +package com.suhel.llamabro.sdk.chat + +import com.suhel.llamabro.sdk.models.ChatEvent +import com.suhel.llamabro.sdk.models.CompletionSnapshot +import com.suhel.llamabro.sdk.toolcall.ToolDefinition +import kotlinx.coroutines.flow.Flow + +/** + * A perfectly encapsulated session representing an active stateful conversation with a model. + */ +interface LlamaChatSession { + suspend fun initialize(tools: List = emptyList()) + suspend fun feedHistory(history: List) + fun completion(message: ChatEvent.UserEvent): Flow +} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/chat/internal/LlamaChatSessionImpl.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/chat/internal/LlamaChatSessionImpl.kt new file mode 100644 index 0000000..ecfb9b2 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/chat/internal/LlamaChatSessionImpl.kt @@ -0,0 +1,159 @@ +package com.suhel.llamabro.sdk.chat.internal + +import com.suhel.llamabro.sdk.chat.LlamaChatSession +import com.suhel.llamabro.sdk.engine.LlamaSession +import com.suhel.llamabro.sdk.format.PromptDecorator +import com.suhel.llamabro.sdk.format.PromptFormatter +import com.suhel.llamabro.sdk.format.ThinkingDecorator +import com.suhel.llamabro.sdk.format.ToolCallDecorator +import com.suhel.llamabro.sdk.models.ChatEvent +import com.suhel.llamabro.sdk.models.CompletionSnapshot +import com.suhel.llamabro.sdk.chat.pipeline.ThinkingMarker +import com.suhel.llamabro.sdk.chat.pipeline.ToolCallMarker +import com.suhel.llamabro.sdk.chat.pipeline.SemanticChunk +import com.suhel.llamabro.sdk.chat.pipeline.lexTags +import com.suhel.llamabro.sdk.chat.pipeline.semanticChunks +import com.suhel.llamabro.sdk.toolcall.ToolCall +import com.suhel.llamabro.sdk.toolcall.ToolDefinition +import com.suhel.llamabro.sdk.toolcall.ToolResult +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow + +internal class LlamaChatSessionImpl( + private val session: LlamaSession, + private val systemPrompt: String, + private val toolCaller: (suspend (List) -> List)? = null +) : LlamaChatSession { + + private var formatter: PromptFormatter = PromptFormatter(session.modelDefinition) + + override suspend fun initialize(tools: List) { + val modelDef = session.modelDefinition + val decorators = mutableListOf() + + // Setup Decorators for injection output mappings + if (modelDef.features.any { it is ThinkingMarker }) { + decorators.add(ThinkingDecorator(modelDef.features.filterIsInstance().first())) + } + if (modelDef.toolCall != null) { + decorators.add(ToolCallDecorator(modelDef.toolCall, tools)) + } + + formatter = PromptFormatter(modelDef, decorators) + + val system = ChatEvent.SystemEvent( + content = systemPrompt, + tools = tools, + ) + + session.setPrefixedPrompt(formatter.formatTurn(system)) + } + + override suspend fun feedHistory(history: List) { + history.forEach { + session.addPrompt(formatter.formatTurn(it)) + } + } + + override fun completion(message: ChatEvent.UserEvent): Flow = flow { + // 1. Prepare prompts via holistic formatter logic + session.addPrompt(formatter.formatTurn(message)) + + // 2. Setup context + val features = session.modelDefinition.features + val toolCallDef = session.modelDefinition.toolCall + + var completedParts = emptyList() + var currentText = "" + var currentThinking = "" + var currentToolCallBuffer = "" + var turnComplete = false + + // Timing state for tok/s metric + var tokenCount = 0 + val startTimeMs = System.currentTimeMillis() + + // 3. Declarative Streaming Loop + while (!turnComplete) { + var executionTriggered = false + + session.generateFlow() + .lexTags(features) + .semanticChunks() + .collect { chunk -> + // Apply immutable state transitions + when (chunk) { + is SemanticChunk.Text -> { + currentText += chunk.content + tokenCount++ + } + is SemanticChunk.Thinking -> { + currentThinking += chunk.content + tokenCount++ + } + is SemanticChunk.ToolCallContent -> currentToolCallBuffer += chunk.content + is SemanticChunk.ToolCallComplete -> { + val callDef = toolCallDef ?: throw IllegalStateException("Tool call triggered but model lacks tool definitions.") + val call = callDef.callParser(currentToolCallBuffer) + completedParts = completedParts + ChatEvent.AssistantEvent.Part.ToolCallPart(call) + currentToolCallBuffer = "" + + val results = toolCaller?.invoke(listOf(call)) ?: emptyList() + for (result in results) { + val toolResult = ChatEvent.ToolResultEvent(result) + session.addPrompt(formatter.formatTurn(toolResult)) + } + executionTriggered = true + } + } + + // Emit reactive intermediate UI state + emit( + CompletionSnapshot( + message = ChatEvent.AssistantEvent( + buildSnapshotParts(completedParts, currentText, currentThinking) + ), + isComplete = false, + isError = false, + error = null + ) + ) + } + + // If the native generator completed naturally without invoking a tool, we're done. + if (!executionTriggered) { + turnComplete = true + } + } + + // 4. Final Emit + val elapsedSeconds = (System.currentTimeMillis() - startTimeMs) / 1000f + val tokensPerSecond = if (elapsedSeconds > 0f && tokenCount > 0) tokenCount / elapsedSeconds else 0f + + completedParts = buildSnapshotParts(completedParts, currentText, currentThinking) + emit( + CompletionSnapshot( + message = ChatEvent.AssistantEvent(completedParts), + isComplete = true, + isError = false, + error = null, + tokensPerSecond = tokensPerSecond + ) + ) + } + + private fun buildSnapshotParts( + completedParts: List, + currentText: String, + currentThinking: String + ): List { + val result = completedParts.toMutableList() + if (currentThinking.isNotEmpty()) { + result.add(ChatEvent.AssistantEvent.Part.ThinkingPart(currentThinking)) + } + if (currentText.isNotEmpty()) { + result.add(ChatEvent.AssistantEvent.Part.TextPart(currentText)) + } + return result + } +} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/chat/pipeline/AllocationOptimizedScanner.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/chat/pipeline/AllocationOptimizedScanner.kt new file mode 100644 index 0000000..8124cc3 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/chat/pipeline/AllocationOptimizedScanner.kt @@ -0,0 +1,126 @@ +package com.suhel.llamabro.sdk.chat.pipeline + +/** + * An allocation-conscious DFA scanner that processes a stream of strings + * extracting distinct tags and text. + * + * Designed to minimize GC overhead by tracking a single running StringBuilder + * and yielding explicit allocations ONLY upon confirmed semantic token emission. + */ +internal class AllocationOptimizedScanner( + private val markers: List +) { + private val buffer = StringBuilder() + private var activeMarker: FeatureMarker? = null + + /** + * Feeds a raw token into the scanner and returns any definitively parsed events. + * @param token The string chunk from the LLM, or `null` to signal the end of the stream. + */ + fun feed(token: String?): List { + if (token != null) { + buffer.append(token) + } + + val events = mutableListOf() + + while (buffer.isNotEmpty()) { + if (activeMarker == null) { + var nearestMarker: FeatureMarker? = null + var nearestFullIdx = -1 + var nearestPartialIdx = -1 + + for (marker in markers) { + val fullIdx = buffer.indexOf(marker.open) + if (fullIdx != -1) { + if (nearestFullIdx == -1 || fullIdx < nearestFullIdx) { + nearestFullIdx = fullIdx + nearestMarker = marker + } + } else { + val partialIdx = findPartialMatch(buffer, marker.open) + if (partialIdx != -1) { + if (nearestPartialIdx == -1 || partialIdx < nearestPartialIdx) { + nearestPartialIdx = partialIdx + } + } + } + } + + if (nearestMarker != null) { + if (nearestFullIdx > 0) { + // Flush preceding pure text + events.add(LexerEvent.Text(buffer.substring(0, nearestFullIdx))) + } + events.add(LexerEvent.TagOpened(nearestMarker)) + activeMarker = nearestMarker + buffer.delete(0, nearestFullIdx + nearestMarker.open.length) + } else if (nearestPartialIdx != -1) { + if (nearestPartialIdx > 0) { + events.add(LexerEvent.Text(buffer.substring(0, nearestPartialIdx))) + buffer.delete(0, nearestPartialIdx) + } + // Wait for more tokens to complete or reject the partial tag boundary + break + } else { + // Entire buffer is pure text. Flush completely. + events.add(LexerEvent.Text(buffer.toString())) + buffer.clear() + break + } + } else { + val currentMarker = activeMarker!! + val closingTag = currentMarker.close + val fullIdx = buffer.indexOf(closingTag) + + if (fullIdx != -1) { + if (fullIdx > 0) { + events.add(LexerEvent.TagContent(currentMarker, buffer.substring(0, fullIdx))) + } + events.add(LexerEvent.TagClosed(currentMarker)) + activeMarker = null + buffer.delete(0, fullIdx + closingTag.length) + } else { + val partialIdx = findPartialMatch(buffer, closingTag) + if (partialIdx != -1) { + if (partialIdx > 0) { + events.add(LexerEvent.TagContent(currentMarker, buffer.substring(0, partialIdx))) + buffer.delete(0, partialIdx) + } + // Stop and safely wait for the remainder of the closing tag across future feeds + break + } else { + // Safely flush entire tag content + events.add(LexerEvent.TagContent(currentMarker, buffer.toString())) + buffer.clear() + break + } + } + } + } + + // If end of stream and activeMarker is still open, we close out forcefully or hold? + // Since LLMs can gracefully stop, the stream ending usually stops cleanly. + // We defer to pipeline if it matters. + + return events + } + + /** + * Quickly identifies if the buffer's tail hosts a prefix of the target tag. + */ + private fun findPartialMatch(buffer: CharSequence, tag: String): Int { + val searchStart = maxOf(0, buffer.length - tag.length + 1) + for (i in searchStart until buffer.length) { + var match = true + for (j in i until buffer.length) { + if (buffer[j] != tag[j - i]) { + match = false + break + } + } + if (match) return i + } + return -1 + } +} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/chat/pipeline/FeatureMarker.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/chat/pipeline/FeatureMarker.kt new file mode 100644 index 0000000..7003150 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/chat/pipeline/FeatureMarker.kt @@ -0,0 +1,9 @@ +package com.suhel.llamabro.sdk.chat.pipeline + +sealed interface FeatureMarker { + val open: String + val close: String +} + +data class ThinkingMarker(override val open: String, override val close: String) : FeatureMarker +data class ToolCallMarker(override val open: String, override val close: String) : FeatureMarker diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/chat/pipeline/LexerEvent.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/chat/pipeline/LexerEvent.kt new file mode 100644 index 0000000..d20fb30 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/chat/pipeline/LexerEvent.kt @@ -0,0 +1,8 @@ +package com.suhel.llamabro.sdk.chat.pipeline + +internal sealed interface LexerEvent { + data class Text(val content: String) : LexerEvent + data class TagOpened(val marker: FeatureMarker) : LexerEvent + data class TagContent(val marker: FeatureMarker, val content: String) : LexerEvent + data class TagClosed(val marker: FeatureMarker) : LexerEvent +} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/chat/pipeline/SemanticChunk.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/chat/pipeline/SemanticChunk.kt new file mode 100644 index 0000000..dbf6358 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/chat/pipeline/SemanticChunk.kt @@ -0,0 +1,8 @@ +package com.suhel.llamabro.sdk.chat.pipeline + +internal sealed interface SemanticChunk { + data class Text(val content: String) : SemanticChunk + data class Thinking(val content: String) : SemanticChunk + data class ToolCallContent(val content: String) : SemanticChunk + data object ToolCallComplete : SemanticChunk +} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/chat/pipeline/StreamOperators.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/chat/pipeline/StreamOperators.kt new file mode 100644 index 0000000..0605a76 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/chat/pipeline/StreamOperators.kt @@ -0,0 +1,48 @@ +package com.suhel.llamabro.sdk.chat.pipeline + +import com.suhel.llamabro.sdk.engine.TokenGenerationResult +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow + +/** + * 1. Syntax Layer: Parses raw strings into LexerEvents based on active feature markers. + * Uses an internal AllocationOptimizedScanner for high-performance, low-memory DFA scanning. + */ +internal fun Flow.lexTags(markers: List): Flow = flow { + val scanner = AllocationOptimizedScanner(markers) + collect { result -> + val events = scanner.feed(result.token) + for (event in events) { + emit(event) + } + + // When stream naturally completes, feed null to flush scanner if needed. + if (result.isComplete) { + val finalEvents = scanner.feed(null) + for (event in finalEvents) { + emit(event) + } + } + } +} + +/** + * 2. Semantic Layer: Buffers syntax events into whole textual or semantic parts. + */ +internal fun Flow.semanticChunks(): Flow = flow { + collect { event -> + when (event) { + is LexerEvent.Text -> emit(SemanticChunk.Text(event.content)) + is LexerEvent.TagContent -> { + when (event.marker) { + is ThinkingMarker -> emit(SemanticChunk.Thinking(event.content)) + is ToolCallMarker -> emit(SemanticChunk.ToolCallContent(event.content)) + } + } + is LexerEvent.TagClosed -> { + if (event.marker is ToolCallMarker) emit(SemanticChunk.ToolCallComplete) + } + is LexerEvent.TagOpened -> {} + } + } +} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/config/ModelConfig.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/config/ModelConfig.kt new file mode 100644 index 0000000..345cee8 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/config/ModelConfig.kt @@ -0,0 +1,22 @@ +package com.suhel.llamabro.sdk.config + +import com.suhel.llamabro.sdk.chat.pipeline.FeatureMarker +import com.suhel.llamabro.sdk.chat.pipeline.ThinkingMarker +import com.suhel.llamabro.sdk.format.PromptFormat +import com.suhel.llamabro.sdk.toolcall.ToolCallDefinition + +data class ModelDefinition( + val loadConfig: ModelLoadConfig, + val promptFormat: PromptFormat, + val features: List = emptyList(), + val toolCall: ToolCallDefinition? = null +) { + val supportsThinking: Boolean get() = features.any { it is ThinkingMarker } +} + +data class ModelLoadConfig( + val path: String, + val useMMap: Boolean = true, + val useMLock: Boolean = false, + val threads: Int = (Runtime.getRuntime().availableProcessors() / 2).coerceAtLeast(1), +) diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/config/SessionConfig.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/config/SessionConfig.kt new file mode 100644 index 0000000..1a128c7 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/config/SessionConfig.kt @@ -0,0 +1,35 @@ +package com.suhel.llamabro.sdk.config + +data class SessionConfig( + val contextSize: Int = 2048, + val overflowStrategy: OverflowStrategy = OverflowStrategy.RollingWindow(), + val inferenceConfig: InferenceConfig = InferenceConfig(), + val decodeConfig: DecodeConfig = DecodeConfig(), + val seed: Int = -1, +) + +data class DecodeConfig( + val batchSize: Int = 2048, + val microBatchSize: Int = 512, +) { + init { + require(batchSize >= microBatchSize) { + "batchSize ($batchSize) must be >= microBatchSize ($microBatchSize)" + } + } +} + +data class InferenceConfig( + val temperature: Float = 0.8f, + val repeatPenalty: Float = 1.0f, + val presencePenalty: Float = 0.0f, + val minP: Float? = 0.1f, + val topP: Float? = null, + val topK: Int? = null, +) + +sealed interface OverflowStrategy { + data object Halt : OverflowStrategy + data object ClearHistory : OverflowStrategy + data class RollingWindow(val dropTokens: Int = 500) : OverflowStrategy +} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/LlamaEngine.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/engine/LlamaEngine.kt similarity index 94% rename from sdk/src/main/java/com/suhel/llamabro/sdk/LlamaEngine.kt rename to sdk/src/main/java/com/suhel/llamabro/sdk/engine/LlamaEngine.kt index b37595e..5e86d59 100644 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/LlamaEngine.kt +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/engine/LlamaEngine.kt @@ -1,11 +1,11 @@ -package com.suhel.llamabro.sdk +package com.suhel.llamabro.sdk.engine +import com.suhel.llamabro.sdk.ProgressListener +import com.suhel.llamabro.sdk.config.ModelDefinition +import com.suhel.llamabro.sdk.config.SessionConfig import com.suhel.llamabro.sdk.internal.LlamaEngineImpl -import com.suhel.llamabro.sdk.internal.ProgressListener import com.suhel.llamabro.sdk.model.LlamaError import com.suhel.llamabro.sdk.model.ResourceState -import com.suhel.llamabro.sdk.model.ModelConfig -import com.suhel.llamabro.sdk.model.SessionConfig import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.flow.Flow @@ -76,7 +76,7 @@ interface LlamaEngine : AutoCloseable { * @throws LlamaError.ModelLoadFailed if the GGUF file is corrupt or incompatible. */ fun create( - modelConfig: ModelConfig, + modelConfig: ModelDefinition, onProgress: ((Float) -> Boolean)? = null ): LlamaEngine { ensureNativeLoaded() @@ -98,7 +98,7 @@ interface LlamaEngine : AutoCloseable { * @param modelConfig Path and loading options for the model. * @return A flow of [ResourceState] representing the loading lifecycle. */ - fun createFlow(modelConfig: ModelConfig): Flow> = callbackFlow { + fun createFlow(modelConfig: ModelDefinition): Flow> = callbackFlow { ensureNativeLoaded() val listener = object : ProgressListener { diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/engine/LlamaSession.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/engine/LlamaSession.kt new file mode 100644 index 0000000..69eae52 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/engine/LlamaSession.kt @@ -0,0 +1,26 @@ +package com.suhel.llamabro.sdk.engine + +import com.suhel.llamabro.sdk.chat.LlamaChatSession +import com.suhel.llamabro.sdk.config.ModelDefinition +import com.suhel.llamabro.sdk.model.ResourceState +import kotlinx.coroutines.flow.Flow + +interface LlamaSession : AutoCloseable { + val modelDefinition: ModelDefinition + + suspend fun setPrefixedPrompt(text: String) + + suspend fun addPrompt(prompt: String) + + suspend fun generate(): TokenGenerationResult + + fun generateFlow(): Flow + + suspend fun clear() + + fun abort() + + suspend fun createChatSession(systemPrompt: String): LlamaChatSession + + fun createChatSessionFlow(systemPrompt: String): Flow> +} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/engine/TokenGenerationResult.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/engine/TokenGenerationResult.kt new file mode 100644 index 0000000..3496b5f --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/engine/TokenGenerationResult.kt @@ -0,0 +1,24 @@ +package com.suhel.llamabro.sdk.engine + +data class TokenGenerationResult( + val token: String?, + val resultCode: TokenGenerationResultCode, + val isComplete: Boolean, +) + +enum class TokenGenerationResultCode(val raw: Int) { + OK(0), + MODEL_NOT_FOUND(1), + MODEL_LOAD_FAILED(2), + BACKEND_LOAD_FAILED(3), + CANCELLED(4), + CONTEXT_INIT_FAILED(10), + CONTEXT_OVERFLOW(11), + DECODE_FAILED(12), + UNKNOWN(99); + + companion object { + private val reverseMap = entries.associateBy { it.raw } + internal fun parse(raw: Int): TokenGenerationResultCode = reverseMap.getOrDefault(raw, UNKNOWN) + } +} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/internal/LlamaEngineImpl.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/engine/internal/LlamaEngineImpl.kt similarity index 77% rename from sdk/src/main/java/com/suhel/llamabro/sdk/internal/LlamaEngineImpl.kt rename to sdk/src/main/java/com/suhel/llamabro/sdk/engine/internal/LlamaEngineImpl.kt index ea76bf2..dd740e5 100644 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/internal/LlamaEngineImpl.kt +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/engine/internal/LlamaEngineImpl.kt @@ -1,12 +1,12 @@ package com.suhel.llamabro.sdk.internal -import com.suhel.llamabro.sdk.LlamaEngine -import com.suhel.llamabro.sdk.LlamaSession - +import com.suhel.llamabro.sdk.engine.LlamaEngine +import com.suhel.llamabro.sdk.engine.LlamaSession +import com.suhel.llamabro.sdk.config.ModelDefinition +import com.suhel.llamabro.sdk.config.SessionConfig +import com.suhel.llamabro.sdk.ProgressListener import com.suhel.llamabro.sdk.model.LlamaError import com.suhel.llamabro.sdk.model.ResourceState -import com.suhel.llamabro.sdk.model.ModelConfig -import com.suhel.llamabro.sdk.model.SessionConfig import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.flow.Flow @@ -21,17 +21,17 @@ import kotlinx.coroutines.withContext * bridge for session creation. */ internal class LlamaEngineImpl( - private val modelConfig: ModelConfig, + private val modelDefinition: ModelDefinition, listener: ProgressListener? = null, ) : LlamaEngine { - + /** Pointer to the native llama_bro_engine structure. */ - private val enginePtr: Long = try { + private val enginePtr: Long = run { val params = NativeCreateParams( - modelPath = modelConfig.modelPath, - useMmap = modelConfig.useMmap, - useMlock = modelConfig.useMlock, - threads = modelConfig.threads, + modelPath = modelDefinition.loadConfig.path, + useMMap = modelDefinition.loadConfig.useMMap, + useMLock = modelDefinition.loadConfig.useMLock, + threads = modelDefinition.loadConfig.threads, ) if (listener != null) { @@ -39,13 +39,11 @@ internal class LlamaEngineImpl( } else { Jni.create(params) } - } catch (e: RuntimeException) { - throw mapNativeError(e) } override suspend fun createSession(sessionConfig: SessionConfig): LlamaSession = withContext(Dispatchers.IO) { - LlamaSessionImpl(enginePtr, sessionConfig, modelConfig) + LlamaSessionImpl(enginePtr, sessionConfig, modelDefinition) } override fun createSessionFlow(sessionConfig: SessionConfig): Flow> = @@ -75,8 +73,8 @@ internal class LlamaEngineImpl( /** Helper class for passing model configuration to the Jni layer. */ private class NativeCreateParams( val modelPath: String, - val useMmap: Boolean, - val useMlock: Boolean, + val useMMap: Boolean, + val useMLock: Boolean, val threads: Int, ) diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/internal/LlamaSessionImpl.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/engine/internal/LlamaSessionImpl.kt similarity index 66% rename from sdk/src/main/java/com/suhel/llamabro/sdk/internal/LlamaSessionImpl.kt rename to sdk/src/main/java/com/suhel/llamabro/sdk/engine/internal/LlamaSessionImpl.kt index 63c76f7..4db5039 100644 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/internal/LlamaSessionImpl.kt +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/engine/internal/LlamaSessionImpl.kt @@ -1,18 +1,23 @@ package com.suhel.llamabro.sdk.internal -import com.suhel.llamabro.sdk.LlamaChatSession -import com.suhel.llamabro.sdk.LlamaSession +import com.suhel.llamabro.sdk.chat.LlamaChatSession +import com.suhel.llamabro.sdk.chat.internal.LlamaChatSessionImpl +import com.suhel.llamabro.sdk.engine.LlamaSession +import com.suhel.llamabro.sdk.config.ModelDefinition +import com.suhel.llamabro.sdk.config.OverflowStrategy +import com.suhel.llamabro.sdk.config.SessionConfig +import com.suhel.llamabro.sdk.engine.TokenGenerationResult +import com.suhel.llamabro.sdk.engine.TokenGenerationResultCode import com.suhel.llamabro.sdk.model.LlamaError -import com.suhel.llamabro.sdk.model.ModelConfig -import com.suhel.llamabro.sdk.model.OverflowStrategy import com.suhel.llamabro.sdk.model.ResourceState -import com.suhel.llamabro.sdk.model.SessionConfig -import com.suhel.llamabro.sdk.model.TokenGenerationResult import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.channelFlow +import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.callbackFlow import kotlinx.coroutines.flow.flowOn +import kotlinx.coroutines.isActive import kotlinx.coroutines.runInterruptible import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock @@ -27,12 +32,13 @@ import kotlinx.coroutines.withContext internal class LlamaSessionImpl( enginePtr: Long, sessionConfig: SessionConfig, - override val modelConfig: ModelConfig, + override val modelDefinition: ModelDefinition ) : LlamaSession { private val mutex = Mutex() + private val result = NativeTokenGenerationResult() /** Pointer to the native llama_bro_session structure. */ - private val ptr: Long = try { + private val ptr: Long = Jni.create( enginePtr = enginePtr, params = NativeCreateParams( @@ -58,32 +64,21 @@ internal class LlamaSessionImpl( microBatchSize = sessionConfig.decodeConfig.microBatchSize, ) ) - } catch (e: RuntimeException) { - throw mapNativeError(e) - } - override suspend fun setSystemPrompt(text: String, addSpecial: Boolean) = + override suspend fun setPrefixedPrompt(text: String) = withContext(Dispatchers.IO) { mutex.withLock { - try { - runInterruptible { - Jni.setSystemPrompt(ptr, text, addSpecial) - } - } catch (e: RuntimeException) { - throw mapNativeError(e) + runInterruptible { + Jni.setSystemPrompt(ptr, text) } } } - override suspend fun ingestPrompt(prompt: String, addSpecial: Boolean) = + override suspend fun addPrompt(prompt: String) = withContext(Dispatchers.IO) { mutex.withLock { - try { - runInterruptible { - Jni.ingestPrompt(ptr, prompt, addSpecial) - } - } catch (e: RuntimeException) { - throw mapNativeError(e) + runInterruptible { + Jni.addUserPrompt(ptr, prompt) } } } @@ -91,30 +86,43 @@ internal class LlamaSessionImpl( override suspend fun generate(): TokenGenerationResult = withContext(Dispatchers.IO) { mutex.withLock { - try { - runInterruptible { - Jni.generate(ptr).let { - TokenGenerationResult( - token = it.token, - isComplete = it.isComplete, - ) - } + runInterruptible { + Jni.generate(ptr, result) + + TokenGenerationResult( + token = result.token, + resultCode = TokenGenerationResultCode.parse(result.resultCode), + isComplete = result.isComplete, + ) + } + } + } + + override fun generateFlow(): Flow = channelFlow { + withContext(Dispatchers.IO) { + mutex.withLock { + var isDone = false + while (!isDone && kotlinx.coroutines.currentCoroutineContext().isActive) { + val genResult = runInterruptible { + Jni.generate(ptr, result) + TokenGenerationResult( + token = result.token, + resultCode = TokenGenerationResultCode.parse(result.resultCode), + isComplete = result.isComplete, + ) } - } catch (e: RuntimeException) { - throw mapNativeError(e) + send(genResult) + isDone = genResult.isComplete || genResult.resultCode != TokenGenerationResultCode.OK } } } + } override suspend fun clear() = withContext(Dispatchers.IO) { mutex.withLock { - try { - runInterruptible { - Jni.clear(ptr) - } - } catch (e: RuntimeException) { - throw mapNativeError(e) + runInterruptible { + Jni.clear(ptr) } } } @@ -159,19 +167,22 @@ internal class LlamaSessionImpl( val topP: Float, val minPEnabled: Boolean, val minP: Float, + // Always-on (no enable field) val repPen: Float, val presencePen: Float, val temp: Float, val seed: Int, + // Decode tuning val batchSize: Int, val microBatchSize: Int, ) private class NativeTokenGenerationResult( - val token: String?, - val isComplete: Boolean, + var token: String? = null, + var resultCode: Int = 0, + var isComplete: Boolean = false, ) private object Jni { @@ -179,10 +190,10 @@ internal class LlamaSessionImpl( external fun create(enginePtr: Long, params: NativeCreateParams): Long @JvmStatic - external fun setSystemPrompt(sessionPtr: Long, text: String, addSpecial: Boolean) + external fun setSystemPrompt(sessionPtr: Long, prompt: String) @JvmStatic - external fun ingestPrompt(sessionPtr: Long, text: String, addSpecial: Boolean) + external fun addUserPrompt(sessionPtr: Long, prompt: String) @JvmStatic external fun clear(sessionPtr: Long) @@ -191,7 +202,7 @@ internal class LlamaSessionImpl( external fun abort(sessionPtr: Long) @JvmStatic - external fun generate(sessionPtr: Long): NativeTokenGenerationResult + external fun generate(sessionPtr: Long, result: NativeTokenGenerationResult) @JvmStatic external fun destroy(sessionPtr: Long) diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/format/PromptDecorator.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/format/PromptDecorator.kt new file mode 100644 index 0000000..649a4f3 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/format/PromptDecorator.kt @@ -0,0 +1,14 @@ +package com.suhel.llamabro.sdk.format + +import com.suhel.llamabro.sdk.models.ChatEvent + +/** + * A capability decorator injects template prefixes for prompts based on features. + * Used by PromptFormatter to decouple formatting logic. + */ +interface PromptDecorator { + fun decorateSystem(content: String): String = content + + // Allows decorators to format their specific Part types + fun decorateAssistantPart(part: ChatEvent.AssistantEvent.Part): String? = null +} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/format/PromptFormat.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/format/PromptFormat.kt new file mode 100644 index 0000000..4da83f3 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/format/PromptFormat.kt @@ -0,0 +1,10 @@ +package com.suhel.llamabro.sdk.format + +data class PromptFormat( + val systemPrefix: String, + val userPrefix: String, + val assistantPrefix: String, + val endOfTurn: String, + val emitAssistantPrefixOnGeneration: Boolean = true, + val stopStrings: List = emptyList(), +) diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/format/PromptFormats.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/format/PromptFormats.kt new file mode 100644 index 0000000..a64cb12 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/format/PromptFormats.kt @@ -0,0 +1,87 @@ +package com.suhel.llamabro.sdk.format + +/** + * Out-of-the-box prompt templates for leading open-source model families. + * + * Each entry faithfully reproduces the template specified in the model's official documentation. + * Using the wrong template will silently degrade generation quality, so choose carefully. + */ +object PromptFormats { + + /** ChatML — used by SmolLM2, Qwen 2.5, and other ChatML-compatible models. */ + val CHAT_ML = PromptFormat( + systemPrefix = "<|im_start|>system\n", + userPrefix = "<|im_start|>user\n", + assistantPrefix = "<|im_start|>assistant\n", + endOfTurn = "<|im_end|>\n", + emitAssistantPrefixOnGeneration = true, + stopStrings = listOf("<|im_end|>") + ) + + /** Meta Llama 3.x / 3.1 / 3.2 family. */ + val LLAMA_3 = PromptFormat( + systemPrefix = "<|start_header_id|>system<|end_header_id|>\n\n", + userPrefix = "<|start_header_id|>user<|end_header_id|>\n\n", + assistantPrefix = "<|start_header_id|>assistant<|end_header_id|>\n\n", + endOfTurn = "<|eot_id|>", + emitAssistantPrefixOnGeneration = true, + stopStrings = listOf("<|eot_id|>", "<|eom_id|>") + ) + + /** Mistral Instruct family. */ + val MISTRAL = PromptFormat( + systemPrefix = "", + userPrefix = "[INST] ", + assistantPrefix = "", + endOfTurn = " [/INST]", + emitAssistantPrefixOnGeneration = true, + stopStrings = listOf("", "[/INST]") + ) + + /** Google Gemma / Gemma 2 family. System messages are prepended to the first user turn. */ + val GEMMA = PromptFormat( + systemPrefix = "user\nSystem: ", + userPrefix = "user\n", + assistantPrefix = "model\n", + endOfTurn = "\n", + emitAssistantPrefixOnGeneration = true, + stopStrings = listOf("") + ) + + /** DeepSeek R1 / R1-Distill family. */ + val DEEPSEEK_R1 = PromptFormat( + systemPrefix = "<|begin of sentence|>", + userPrefix = "User: ", + assistantPrefix = "Assistant: ", + endOfTurn = "<|end of sentence|>", + emitAssistantPrefixOnGeneration = true, + stopStrings = listOf("<|end of sentence|>") + ) + + /** Qwen 2.5 family (identical to ChatML). */ + val QWEN_2_5 = CHAT_ML + + /** + * NVIDIA Nemotron family. + * + * Uses the `` / `` sentinel format as specified in the official + * [Nemotron-Mini-4B-Instruct](https://huggingface.co/nvidia/Nemotron-Mini-4B-Instruct) card. + * + * Template: + * ``` + * System + * {system_prompt} + * User + * {user_message} + * Assistant\n + * ``` + */ + val NEMOTRON = PromptFormat( + systemPrefix = "System\n", + userPrefix = "User\n", + assistantPrefix = "Assistant\n", + endOfTurn = "\n", + emitAssistantPrefixOnGeneration = true, + stopStrings = listOf("") + ) +} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/format/PromptFormatter.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/format/PromptFormatter.kt new file mode 100644 index 0000000..ccc344b --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/format/PromptFormatter.kt @@ -0,0 +1,89 @@ +package com.suhel.llamabro.sdk.format + +import com.suhel.llamabro.sdk.config.ModelDefinition +import com.suhel.llamabro.sdk.models.ChatEvent +import com.suhel.llamabro.sdk.chat.pipeline.ThinkingMarker + +internal class PromptFormatter( + private val modelDefinition: ModelDefinition, + private val decorators: List = emptyList() +) { + private val template = modelDefinition.promptFormat + + /** + * Governs the exact handshake strings required to turn a semantic chat event into a raw prompt. + */ + fun formatTurn(event: ChatEvent): String { + return when (event) { + is ChatEvent.SystemEvent -> formatSystem(event) + is ChatEvent.UserEvent -> formatUserTurnStart(event) + is ChatEvent.AssistantEvent -> formatAssistant(event) + is ChatEvent.ToolResultEvent -> formatToolResult(event) + } + } + + private fun formatSystem(event: ChatEvent.SystemEvent): String { + var content = event.content + for (decorator in decorators) { + content = decorator.decorateSystem(content) + } + return buildString { + append(template.systemPrefix) + append(content) + append(template.endOfTurn) + } + } + + private fun formatUserTurnStart(event: ChatEvent.UserEvent): String { + return buildString { + append(template.userPrefix) + append(event.content) + append(template.endOfTurn) + + if (template.emitAssistantPrefixOnGeneration) { + append(template.assistantPrefix) + } + + if (event.think) { + val thinkingMarker = modelDefinition.features.filterIsInstance().firstOrNull() + if (thinkingMarker != null) { + append(thinkingMarker.open) + append("\n") + } + } + } + } + + private fun formatAssistant(event: ChatEvent.AssistantEvent): String { + return buildString { + if (!template.emitAssistantPrefixOnGeneration) { + append(template.assistantPrefix) + } + + for (part in event.parts) { + var decorated: String? = null + for (decorator in decorators) { + val d = decorator.decorateAssistantPart(part) + if (d != null) { + decorated = d + break + } + } + if (decorated != null) { + append(decorated) + } else if (part is ChatEvent.AssistantEvent.Part.TextPart) { + append(part.content) + } + } + append(template.endOfTurn) + } + } + + private fun formatToolResult(event: ChatEvent.ToolResultEvent): String { + return buildString { + append(template.userPrefix) + append(event.result.toString()) + append(template.endOfTurn) + } + } +} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/format/ThinkingDecorator.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/format/ThinkingDecorator.kt new file mode 100644 index 0000000..d0e338a --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/format/ThinkingDecorator.kt @@ -0,0 +1,21 @@ +package com.suhel.llamabro.sdk.format + +import com.suhel.llamabro.sdk.models.ChatEvent +import com.suhel.llamabro.sdk.chat.pipeline.ThinkingMarker + +class ThinkingDecorator( + private val thinking: ThinkingMarker +) : PromptDecorator { + + override fun decorateAssistantPart(part: ChatEvent.AssistantEvent.Part): String? { + if (part !is ChatEvent.AssistantEvent.Part.ThinkingPart) return null + return buildString { + append(thinking.open) + append("\n") + append(part.content) + append("\n") + append(thinking.close) + append("\n") + } + } +} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/format/ToolCallDecorator.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/format/ToolCallDecorator.kt new file mode 100644 index 0000000..9fcb4c3 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/format/ToolCallDecorator.kt @@ -0,0 +1,28 @@ +package com.suhel.llamabro.sdk.format + +import com.suhel.llamabro.sdk.models.ChatEvent +import com.suhel.llamabro.sdk.toolcall.ToolDefinition +import com.suhel.llamabro.sdk.toolcall.ToolCallDefinition + +class ToolCallDecorator( + private val toolCallDefinition: ToolCallDefinition, + private val tools: List = emptyList() +) : PromptDecorator { + + override fun decorateSystem(content: String): String { + val toolsText = if (tools.isNotEmpty()) { + toolCallDefinition.definitionFormatter(tools) + } else { + "" + } + + if (toolsText.isEmpty()) return content + + return "$content\n$toolsText" + } + + override fun decorateAssistantPart(part: ChatEvent.AssistantEvent.Part): String? { + if (part !is ChatEvent.AssistantEvent.Part.ToolCallPart) return null + return toolCallDefinition.callSerializer(part.call) + } +} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/internal/LlamaChatSessionImpl.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/internal/LlamaChatSessionImpl.kt deleted file mode 100644 index faf6876..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/internal/LlamaChatSessionImpl.kt +++ /dev/null @@ -1,239 +0,0 @@ -package com.suhel.llamabro.sdk.internal - -import com.suhel.llamabro.sdk.LlamaChatSession -import com.suhel.llamabro.sdk.LlamaSession -import com.suhel.llamabro.sdk.model.Completion -import com.suhel.llamabro.sdk.model.LlamaError -import com.suhel.llamabro.sdk.model.Message -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.currentCoroutineContext -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.flow.flowOn -import kotlinx.coroutines.flow.onCompletion -import kotlinx.coroutines.isActive -import kotlinx.coroutines.withContext - -internal class LlamaChatSessionImpl( - private val session: LlamaSession, - private val systemPrompt: String -) : LlamaChatSession { - private val fmt = session.modelConfig.promptFormat - private val parser = TokenStreamParser( - thinkingStart = fmt.thinkStart, - thinkingEnd = fmt.thinkEnd, - stopStrings = fmt.stopStrings, - ) - private val prompter = Prompter(fmt) - - override val supportsThinking: Boolean - get() = session.modelConfig.supportsThinking - - override fun completion( - prompt: String, - enableThinking: Boolean, - maxThinkingTokens: Int?, - ): Flow = flow { - var completionState = Completion() - var tokenCount = 0 - var thinkingTokenCount = 0 - val contentBuilder = StringBuilder() - val thinkingBuilder = StringBuilder() - val thinkingEnabled = enableThinking && session.modelConfig.supportsThinking - - parser.reset(thinkingEnabled) - - val formattedPrompt = buildString { - append(prompter.user(prompt)) - append(prompter.assistantStart()) - - if(thinkingEnabled) { - append(prompter.thinkingStart()) - } - } - - session.ingestPrompt( - prompt = formattedPrompt, - addSpecial = prompter.shouldAddSpecial() - ) - - val startTime = System.nanoTime() - - while (currentCoroutineContext().isActive) { - val generation = try { - session.generate() - } catch (_: LlamaError.Cancelled) { - emit( - completionState.finalize( - tokenCount = tokenCount, - startTime = startTime, - isInterrupted = true, - contentBuilder = contentBuilder, - thinkingBuilder = thinkingBuilder - ) - ) - return@flow - } catch (_: LlamaError.ContextOverflow) { - // Context exhausted and strategy cannot recover — surface as an interrupted completion. - emit( - completionState.finalize( - tokenCount = tokenCount, - startTime = startTime, - isInterrupted = true, - contentBuilder = contentBuilder, - thinkingBuilder = thinkingBuilder - ) - ) - return@flow - } catch (e: LlamaError) { - // Fatal errors (DecodeFailed, NativeException, etc.) are emitted as data, - // not thrown — the flow always terminates cleanly. - emit( - completionState.finalize( - tokenCount = tokenCount, - startTime = startTime, - isInterrupted = false, - contentBuilder = contentBuilder, - thinkingBuilder = thinkingBuilder, - error = e - ) - ) - return@flow - } - - generation.token?.let { token -> - tokenCount++ - if (parser.isThinking) thinkingTokenCount++ - - val contentLenBefore = contentBuilder.length - val thinkingLenBefore = thinkingBuilder.length - val stateBefore = parser.isThinking - - // The parser directly modifies the builders. 0 allocations. - parser.process(token, contentBuilder, thinkingBuilder) - - // Only emit a new state if the parser actually appended text or flipped state. - if ( - contentBuilder.length > contentLenBefore || - thinkingBuilder.length > thinkingLenBefore || - parser.isThinking != stateBefore - ) { - completionState = completionState.copy( - contentText = if (contentBuilder.isEmpty()) null else contentBuilder.toString(), - thinkingText = if (thinkingBuilder.isEmpty()) null else thinkingBuilder.toString() - ) - emit(completionState) - } - - // Stop string detected — treat as a clean end of generation. - if (parser.isStopped) { - emit( - completionState.finalize( - tokenCount = tokenCount, - startTime = startTime, - isInterrupted = false, - contentBuilder = contentBuilder, - thinkingBuilder = thinkingBuilder - ) - ) - return@flow - } - } - - // Thinking budget exhausted — force-close the thinking block so the model - // begins its response. Done outside the token `let` to allow suspension. - if ( - parser.isThinking && - maxThinkingTokens != null && - thinkingTokenCount >= maxThinkingTokens - ) { - val closeTag = prompter.thinkingEnd() - session.ingestPrompt(closeTag, addSpecial = false) - parser.process(closeTag, contentBuilder, thinkingBuilder) - completionState = completionState.copy( - thinkingText = thinkingBuilder.ifBlank { null }?.toString()?.trim() - ) - emit(completionState) - } - - if (generation.isComplete) { - parser.flush(contentBuilder, thinkingBuilder) - emit( - completionState.finalize( - tokenCount = tokenCount, - startTime = startTime, - isInterrupted = false, - contentBuilder = contentBuilder, - thinkingBuilder = thinkingBuilder - ) - ) - break - } - } - } - .onCompletion { cause -> - if (cause != null) { - session.abort() - } - } - .flowOn(Dispatchers.IO) - - /** Finalizes completion state with performance metrics and trimming. */ - private fun Completion.finalize( - tokenCount: Int, - startTime: Long, - isInterrupted: Boolean, - contentBuilder: StringBuilder, - thinkingBuilder: StringBuilder, - error: LlamaError? = null, - ): Completion { - val endTime = System.nanoTime() - val durationNs = (endTime - startTime).coerceAtLeast(1) - val tps = (tokenCount.toDouble() / durationNs * 1e9).toFloat() - - return this.copy( - thinkingText = thinkingBuilder.ifBlank { null }?.toString()?.trim(), - contentText = contentBuilder.ifBlank { null }?.toString()?.trim(), - tokensPerSecond = tps, - isComplete = true, - isInterrupted = isInterrupted, - error = error, - ) - } - - override suspend fun reset() = - withContext(Dispatchers.IO) { - session.clear() - } - - /** - * Ingests [messages] into the session, oldest-first. If the context fills up mid-load, - * the oldest message is dropped and the remaining slice is retried from the system-prompt - * boundary — ensuring the most recent history always fits. - */ - override suspend fun loadHistory(messages: List) = - withContext(Dispatchers.IO) { - var start = 0 - retry@ while (start < messages.size) { - session.clear() - for (i in start until messages.size) { - try { - session.ingestPrompt(prompter.format(messages[i])) - } catch (_: LlamaError.ContextOverflow) { - start++ - continue@retry - } - } - return@withContext - } - } - - /** Initial injection of the BOS and system prompt during session creation. */ - internal suspend fun initialize() = - withContext(Dispatchers.IO) { - session.setSystemPrompt( - text = prompter.system(systemPrompt), - addSpecial = prompter.shouldAddSpecial() - ) - } -} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/internal/NativeErrorMapper.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/internal/NativeErrorMapper.kt deleted file mode 100644 index 7b7efcf..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/internal/NativeErrorMapper.kt +++ /dev/null @@ -1,29 +0,0 @@ -package com.suhel.llamabro.sdk.internal - -import com.suhel.llamabro.sdk.model.LlamaError - -/** - * Maps a [RuntimeException] thrown by JNI into a typed [LlamaError]. - * - * The JNI layer encodes errors as ":" in the - * exception message. This function parses that encoding and returns the - * appropriate SDK exception. - */ -internal fun mapNativeError(e: RuntimeException): LlamaError { - val message = e.message ?: return LlamaError.NativeException("Unknown native error", e) - - val colonIdx = message.indexOf(':') - val code = if (colonIdx > 0) message.substring(0, colonIdx).toIntOrNull() else null - val detail = if (colonIdx > 0) message.substring(colonIdx + 1) else message - - return when (code) { - 1 -> LlamaError.ModelNotFound(detail) - 2 -> LlamaError.ModelLoadFailed(detail, e) - 3 -> LlamaError.BackendLoadFailed(detail) - 4 -> LlamaError.Cancelled() - 10 -> LlamaError.ContextInitFailed(e) - 11 -> LlamaError.ContextOverflow() - 12 -> LlamaError.DecodeFailed(detail.toIntOrNull() ?: -1) - else -> LlamaError.NativeException(detail, e) - } -} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/internal/ProgressListener.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/internal/ProgressListener.kt deleted file mode 100644 index 4c61711..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/internal/ProgressListener.kt +++ /dev/null @@ -1,17 +0,0 @@ -package com.suhel.llamabro.sdk.internal - -/** - * Internal JNI callback invoked by the native layer during model loading. - * - * Each invocation corresponds to a progress update from the llama.cpp model - * loader, typically reflecting the percentage of the model file read into memory. - */ -internal interface ProgressListener { - /** - * Called by the native layer with the current loading progress. - * - * @param progress A value from 0.0 to 1.0 reflecting loading status. - * @return true to continue loading, false to abort the operation. - */ - fun onProgress(progress: Float): Boolean -} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/internal/Prompter.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/internal/Prompter.kt deleted file mode 100644 index 1735a29..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/internal/Prompter.kt +++ /dev/null @@ -1,67 +0,0 @@ -package com.suhel.llamabro.sdk.internal - -import com.suhel.llamabro.sdk.model.Message -import com.suhel.llamabro.sdk.model.PromptFormat - -/** - * Internal utility to wrap raw message text into model-specific chat templates. - * - * This class ensures that System, User, and Assistant messages are prefixed - * and suffixed correctly according to the [PromptFormat] provided by the model. - * - * The assistant turn has an explicit lifecycle: - * - [assistantStart] — opens the turn (prefix only, used before generation). - * - [assistantEnd] — closes the turn (suffix + eos, used after generation). - * - [assistant] — formats a complete turn (prefix + content + suffix + eos, - * used for history loading). - */ -internal class Prompter(private val fmt: PromptFormat) { - /** Beginning of stream token. */ - fun bos(): String = fmt.bos.orEmpty() - - /** End of stream token. */ - fun eos(): String = fmt.eos.orEmpty() - - /** - * Whether the tokenizer should prepend the model's native BOS token. - * - * When the [PromptFormat] supplies an explicit [PromptFormat.bos] string, - * we embed it ourselves in [system] and tell the tokenizer NOT to add another. - * When it is null, we let the tokenizer handle BOS automatically. - */ - fun shouldAddSpecial(): Boolean = fmt.bos == null - - /** Formats a single user message (complete turn). */ - fun user(prompt: String): String = - "${fmt.userPrefix}$prompt${fmt.userSuffix}" - - /** Formats a single system instruction (complete turn with BOS). */ - fun system(prompt: String): String = - "${bos()}${fmt.systemPrefix}$prompt${fmt.systemSuffix}" - - fun thinkingStart(): String = "${fmt.thinkStart}\n" - - fun thinkingEnd(): String = "${fmt.thinkEnd}\n" - - /** Returns the assistant turn opening prefix (injected before generation). */ - fun assistantStart(): String = fmt.assistantPrefix - - /** Returns the assistant turn closing tokens (injected after generation). */ - fun assistantEnd(): String = "${fmt.assistantSuffix}${eos()}" - - /** Formats a complete assistant message (used for history loading). */ - fun assistant(prompt: String, thinking: String? = null): String { - val thinkingBlock = if (!thinking.isNullOrBlank()) { - "${thinkingStart()}$thinking${thinkingEnd()}" - } else "" - - return "${assistantStart()}$thinkingBlock$prompt${assistantEnd()}" - } - - /** Dispatches formatting based on the message role. */ - fun format(message: Message): String = - when (message) { - is Message.User -> user(message.content) - is Message.Assistant -> assistant(message.content, message.thinking) - } -} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/internal/TokenStreamParser.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/internal/TokenStreamParser.kt deleted file mode 100644 index 441d869..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/internal/TokenStreamParser.kt +++ /dev/null @@ -1,139 +0,0 @@ -package com.suhel.llamabro.sdk.internal - -internal class TokenStreamParser( - private val thinkingStart: String = DEFAULT_THINKING_START, - private val thinkingEnd: String = DEFAULT_THINKING_END, - private val stopStrings: List = emptyList(), -) { - private val buffer = StringBuilder( - maxOf( - thinkingStart.length, - thinkingEnd.length, - stopStrings.maxOfOrNull { it.length } ?: 0, - ) - ) - - var isThinking = false - private set - - /** True once a stop string has been matched. Further [process] calls are no-ops. */ - var isStopped = false - private set - - fun process( - token: String, - contentBuilder: StringBuilder, - thinkingBuilder: StringBuilder, - ) { - if (isStopped) return - buffer.append(token) - - while (true) { - if (!isThinking) { - val tagIdx = buffer.indexOf(thinkingStart) - val stopMatch = findEarliestStop() - - when { - // Stop string fires at or before the think tag — stop wins on a tie. - stopMatch != null && (tagIdx == -1 || stopMatch.first <= tagIdx) -> { - if (stopMatch.first > 0) contentBuilder.append(buffer, 0, stopMatch.first) - buffer.clear() - isStopped = true - return - } - - // Think-start found with no earlier stop string. - tagIdx != -1 -> { - if (tagIdx > 0) contentBuilder.append(buffer, 0, tagIdx) - isThinking = true - buffer.delete(0, tagIdx + thinkingStart.length) - continue - } - - // No full match — hold back as much as could be the start of any pattern. - else -> { - val hold = maxOf( - getPartialMatchLength(buffer, thinkingStart), - stopStrings.maxOfOrNull { getPartialMatchLength(buffer, it) } ?: 0, - ) - val safe = buffer.length - hold - if (safe > 0) { - contentBuilder.append(buffer, 0, safe) - buffer.delete(0, safe) - } - break - } - } - } else { - // In thinking mode: only watch for the closing tag. - val endIdx = buffer.indexOf(thinkingEnd) - if (endIdx != -1) { - if (endIdx > 0) thinkingBuilder.append(buffer, 0, endIdx) - isThinking = false - buffer.delete(0, endIdx + thinkingEnd.length) - continue - } else { - val hold = getPartialMatchLength(buffer, thinkingEnd) - val safe = buffer.length - hold - if (safe > 0) { - thinkingBuilder.append(buffer, 0, safe) - buffer.delete(0, safe) - } - break - } - } - } - } - - fun flush(contentBuilder: StringBuilder, thinkingBuilder: StringBuilder) { - if (buffer.isNotEmpty()) { - val dest = if (isThinking) thinkingBuilder else contentBuilder - dest.append(buffer) - buffer.clear() - } - } - - fun reset(startThinking: Boolean = false) { - buffer.clear() - isThinking = startThinking - isStopped = false - } - - /** Returns the (startIndex, length) of the earliest stop-string match, or null. */ - private fun findEarliestStop(): Pair? { - if (stopStrings.isEmpty()) return null - var result: Pair? = null - for (ss in stopStrings) { - val idx = buffer.indexOf(ss) - if (idx != -1 && (result == null || idx < result.first)) { - result = idx to ss.length - } - } - return result - } - - /** - * Returns how many characters at the END of [sb] could be the beginning of [pattern]. - * Bounded by pattern.length - 1, so it is O(pattern.length) in the worst case. - */ - private fun getPartialMatchLength(sb: StringBuilder, pattern: String): Int { - val maxOverlap = minOf(sb.length, pattern.length - 1) - if (maxOverlap == 0) return 0 - for (i in (sb.length - maxOverlap) until sb.length) { - var isMatch = true - for (j in 0 until (sb.length - i)) { - if (sb[i + j] != pattern[j]) { - isMatch = false - break - } - } - if (isMatch) return sb.length - i - } - return 0 - } - - companion object { - private const val DEFAULT_THINKING_START = "" - private const val DEFAULT_THINKING_END = "" - } -} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/model/Completion.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/model/Completion.kt deleted file mode 100644 index 796490e..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/model/Completion.kt +++ /dev/null @@ -1,38 +0,0 @@ -package com.suhel.llamabro.sdk.model - -/** - * An accumulated snapshot of a response generated by the LLM. - * - * In a streaming context (e.g., [com.suhel.llamabro.sdk.LlamaChatSession.completion]), - * this object is emitted multiple times. Each emission represents the state of the - * response at that point in time, with [contentText] and [thinkingText] containing - * all text generated so far. - * - * ### Terminal state - * The final emission always has [isComplete] = true. Inspect [isInterrupted] and [error] - * to determine how generation ended: - * - * | [isComplete] | [isInterrupted] | [error] | Meaning | - * |---|---|---|---| - * | true | false | null | Natural EOS or stop-string hit | - * | true | true | null | Cancelled by user or context overflow | - * | true | false | non-null| Fatal error (decode failure, native crash) | - * - * The [com.suhel.llamabro.sdk.LlamaChatSession.completion] flow **never throws**; - * all failure conditions are surfaced here so callers need only collect the flow. - * - * @property thinkingText Accumulated reasoning text (inside `...` blocks). - * @property contentText Accumulated final response text. - * @property tokensPerSecond Generation speed. Only populated on the final emission. - * @property isComplete True on the last emission. - * @property isInterrupted True if generation was cut short by cancellation or overflow. - * @property error Non-null only for fatal errors (e.g. [LlamaError.DecodeFailed]). - */ -data class Completion( - val thinkingText: String? = null, - val contentText: String? = null, - val tokensPerSecond: Float? = null, - val isComplete: Boolean = false, - val isInterrupted: Boolean = false, - val error: LlamaError? = null, -) diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/model/DecodeConfig.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/model/DecodeConfig.kt deleted file mode 100644 index 1020889..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/model/DecodeConfig.kt +++ /dev/null @@ -1,25 +0,0 @@ -package com.suhel.llamabro.sdk.model - -/** - * Low-level llama.cpp context tuning parameters. - * - * The defaults are optimized for typical Android on-device inference. Only adjust - * these if you have a specific performance profile or memory constraint in mind. - * - * @property batchSize Maximum number of tokens processed in a single llama_decode - * call (n_batch). Larger values speed up prompt ingestion - * (pre-fill) at the cost of peak memory usage. - * @property microBatchSize Physical GGML compute batch size (n_ubatch). Must be less - * than or equal to [batchSize]. Affects SIMD register - * utilization and memory bandwidth. - */ -data class DecodeConfig( - val batchSize: Int = 2048, - val microBatchSize: Int = 512, -) { - init { - require(batchSize >= microBatchSize) { - "batchSize ($batchSize) must be >= microBatchSize ($microBatchSize)" - } - } -} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/model/InferenceConfig.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/model/InferenceConfig.kt deleted file mode 100644 index 751f670..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/model/InferenceConfig.kt +++ /dev/null @@ -1,30 +0,0 @@ -package com.suhel.llamabro.sdk.model - -/** - * Parameters for controlling token generation and sampling. - * - * These settings influence the "creativity" and structure of the model's response. - * - * @property temperature Controls the randomness of predictions. Lower values (e.g., 0.1) - * make the output more deterministic (greedy), while higher - * values (e.g., 1.2) make it more diverse and creative. - * @property repeatPenalty Discourages the model from repeating the same sequence of - * tokens. 1.0 means no penalty. - * @property presencePenalty Penalizes tokens based on whether they have already appeared - * in the generated text. Higher values encourage the model to - * talk about new topics. - * @property minP A threshold for sampling. Only tokens with a probability - * relative to the most likely token greater than this - * value are considered. This is often preferred over Top-P. - * @property topP Nucleus sampling: only the smallest set of most probable - * tokens whose cumulative probability exceeds P are considered. - * @property topK Only the K most likely tokens are considered for sampling. - */ -data class InferenceConfig( - val temperature: Float = 0.8f, - val repeatPenalty: Float = 1.1f, - val presencePenalty: Float = 0.0f, - val minP: Float? = 0.1f, - val topP: Float? = null, - val topK: Int? = null, -) diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/model/LlamaError.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/model/LlamaError.kt index 4b016e6..c0391f2 100644 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/model/LlamaError.kt +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/model/LlamaError.kt @@ -37,7 +37,7 @@ sealed class LlamaError(message: String, cause: Throwable? = null) : Exception(m /** * The operation was explicitly cancelled or aborted. - * This is thrown if [com.suhel.llamabro.sdk.LlamaSession.abort] is called or + * This is thrown if [com.suhel.llamabro.sdk.LlamaSession.abort] is called or * if the coroutine is cancelled during a native loop. */ class Cancelled : @@ -48,14 +48,14 @@ sealed class LlamaError(message: String, cause: Throwable? = null) : Exception(m /** * llama_init_from_model returned null. * This typically indicates an Out Of Memory (OOM) condition or an invalid - * [SessionConfig]. + * [com.suhel.llamabro.sdk.SessionConfig]. */ class ContextInitFailed(cause: Throwable? = null) : LlamaError("Failed to initialize inference context", cause) /** - * The context window is full and the configured [OverflowStrategy] cannot recover. - * This is only thrown by [OverflowStrategy.Halt] — other strategies handle + * The context window is full and the configured [com.suhel.llamabro.sdk.OverflowStrategy] cannot recover. + * This is only thrown by [com.suhel.llamabro.sdk.OverflowStrategy.Halt] — other strategies handle * this silently by clearing or shifting history. */ class ContextOverflow : diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/model/Message.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/model/Message.kt deleted file mode 100644 index 0708a80..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/model/Message.kt +++ /dev/null @@ -1,36 +0,0 @@ -package com.suhel.llamabro.sdk.model - -/** - * Represents a single message in a conversation history. - * - * This sealed interface defines the two roles supported by the SDK: [User] and [Assistant]. - * These messages are used by [com.suhel.llamabro.sdk.LlamaChatSession.loadHistory] to - * restore state and are formatted according to the model's [PromptFormat]. - */ -sealed interface Message { - /** The textual content of the message. */ - val content: String - - /** - * A message sent by the human participant. - * - * @param content The raw text from the user. - */ - data class User(override val content: String) : Message - - /** - * A message generated by the AI model. - * - * @param content The main response text generated by the model. - * @param thinking The model's internal reasoning process. This is only - * populated for reasoning models (e.g., DeepSeek-R1) - * that output `...` blocks. - * @param tokensPerSecond The generation speed for this specific message. - * Null if the message was restored from history. - */ - data class Assistant( - override val content: String, - val thinking: String? = null, - val tokensPerSecond: Float? = null, - ) : Message -} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/model/ModelConfig.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/model/ModelConfig.kt deleted file mode 100644 index 3c3d27c..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/model/ModelConfig.kt +++ /dev/null @@ -1,31 +0,0 @@ -package com.suhel.llamabro.sdk.model - -/** - * Configuration for loading a GGUF model into a [com.suhel.llamabro.sdk.LlamaEngine]. - * - * This class defines how the model file is read and held in memory. Choosing the - * right parameters here is crucial for performance and stability on mobile devices. - * - * @property modelPath The absolute filesystem path to the `.gguf` model file. - * @property promptFormat The chat template that matches how the model was trained. - * Use [PromptFormats] for common models. - * @property useMmap If true, the model is memory-mapped. This allows faster - * initialization and lets the OS manage memory paging. - * Usually set to true. - * @property useMlock If true, the model's memory pages are locked in RAM. - * This prevents the OS from swapping them out, ensuring - * consistent performance but potentially causing OOMs if - * the model is larger than available RAM. - * @property supportsThinking Whether the model supports thinking. - * @property threads The number of CPU threads to use for compute-heavy - * inference tasks. Defaults to half the available cores - * to balance performance and battery/thermal impact. - */ -data class ModelConfig( - val modelPath: String, - val promptFormat: PromptFormat, - val supportsThinking: Boolean = false, - val useMmap: Boolean = true, - val useMlock: Boolean = false, - val threads: Int = (Runtime.getRuntime().availableProcessors() / 2).coerceAtLeast(1), -) diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/model/OverflowStrategy.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/model/OverflowStrategy.kt deleted file mode 100644 index 1cb75c4..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/model/OverflowStrategy.kt +++ /dev/null @@ -1,39 +0,0 @@ -package com.suhel.llamabro.sdk.model - -/** - * Defines how the session behaves when the context window fills up. - * - * The context window is the finite memory buffer ([SessionConfig.contextSize]) - * that holds the system prompt, conversation history, and the new tokens being - * generated. When this limit is reached, a strategy must be chosen to make - * room for more tokens. - */ -sealed interface OverflowStrategy { - /** - * Halts generation and throws a [LlamaError.ContextOverflow]. - * - * This is useful for strict data extraction or summarization tasks where - * losing any part of the input context would invalidate the result. - */ - data object Halt : OverflowStrategy - - /** - * Clears the entire conversation history, keeping only the system prompt. - * - * This starts the conversation from a clean slate once the limit is hit. - * Note that the user message that caused the overflow is also cleared. - */ - data object ClearHistory : OverflowStrategy - - /** - * The smart default. Natively shifts the KV cache to the left, dropping - * the oldest messages while perfectly preserving the system prompt. - * - * This allows for "infinite" feeling conversations by always keeping the - * most recent context. - * - * @property dropTokens The number of tokens to clear from the oldest part - * of the history when an overflow occurs. Default: 500. - */ - data class RollingWindow(val dropTokens: Int = 500) : OverflowStrategy -} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/model/PromptFormat.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/model/PromptFormat.kt deleted file mode 100644 index 6ac347a..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/model/PromptFormat.kt +++ /dev/null @@ -1,36 +0,0 @@ -package com.suhel.llamabro.sdk.model - -/** - * Defines the chat template used to wrap user and assistant messages. - * - * LLMs are trained with specific special tokens to distinguish between different - * roles (System, User, Assistant). Providing the correct format is essential - * for the model to follow instructions and maintain conversation flow. - * - * Use [PromptFormats] to select a pre-defined template for popular models. - * - * @property systemPrefix The token(s) that start a system instruction. - * @property systemSuffix The token(s) that end a system instruction. - * @property userPrefix The token(s) that start a user message. - * @property userSuffix The token(s) that end a user message. - * @property assistantPrefix The token(s) that start an assistant response. - * @property assistantSuffix The token(s) that end an assistant response (often used as a stop sequence). - * @property bos The Beginning of Sentence token (e.g., "" or "<|begin_of_text|>"). - * Inserted once at the very start of a session. - * @property eos The End of Sentence token (e.g., "" or "<|end_of_text|>"). - * @property thinkStart The token(s) that start a thinking block. - * @property thinkEnd The token(s) that end a thinking block. - */ -data class PromptFormat( - val systemPrefix: String, - val systemSuffix: String, - val userPrefix: String, - val userSuffix: String, - val assistantPrefix: String, - val assistantSuffix: String, - val bos: String? = null, - val eos: String? = null, - val thinkStart: String = "", - val thinkEnd: String = "", - val stopStrings: List = emptyList(), -) diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/model/PromptFormats.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/model/PromptFormats.kt deleted file mode 100644 index 6abc7cc..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/model/PromptFormats.kt +++ /dev/null @@ -1,74 +0,0 @@ -package com.suhel.llamabro.sdk.model - -/** - * Pre-defined [PromptFormat] templates for popular AI model families. - * - * Each model is trained with a specific way of separating roles (System, User, Assistant). - * Using the wrong format will lead to poor model performance, hallucinations, - * or the model failing to recognize when it should stop generating. - * - * If your model is not listed here, consult the model's documentation on HuggingFace - * to determine the correct chat template and create a custom [PromptFormat]. - */ -object PromptFormats { - /** - * The ChatML format used by models like OpenAI (historically), Qwen, and many others. - * Uses `<|im_start|>` and `<|im_end|>` tags. - */ - val ChatML = PromptFormat( - systemPrefix = "<|im_start|>system\n", - systemSuffix = "<|im_end|>", - userPrefix = "<|im_start|>user\n", - userSuffix = "<|im_end|>", - assistantPrefix = "<|im_start|>assistant\n", - assistantSuffix = "<|im_end|>", - stopStrings = listOf("<|im_start|>"), - ) - - /** - * The official prompt format for Meta's Llama 3 and 3.1 models. - * Uses `<|start_header_id|>` and `<|eot_id|>` tags. - */ - val Llama3 = PromptFormat( - bos = "<|begin_of_text|>", - systemPrefix = "<|start_header_id|>system<|end_header_id|>\n\n", - systemSuffix = "<|eot_id|>", - userPrefix = "<|start_header_id|>user<|end_header_id|>\n\n", - userSuffix = "<|eot_id|>", - assistantPrefix = "<|start_header_id|>assistant<|end_header_id|>\n\n", - assistantSuffix = "<|eot_id|>", - stopStrings = listOf("<|start_header_id|>"), - ) - - /** - * The instruction format used by Mistral-7B-Instruct. - * Uses `[INST]` and `[/INST]` tags. Note that Mistral often doesn't have - * an explicit system role in its base template. - */ - val Mistral = PromptFormat( - bos = "", - eos = "", - userPrefix = "[INST] ", - userSuffix = " [/INST]", - assistantPrefix = "", - assistantSuffix = "", - systemPrefix = "", - systemSuffix = "", - stopStrings = listOf("[INST]"), - ) - - /** - * The prompt format for Google's Gemma 3 models. - * Uses `` and `` tags. - */ - val Gemma3 = PromptFormat( - bos = "", - systemPrefix = "system\n", - systemSuffix = "", - userPrefix = "\nuser\n", - userSuffix = "", - assistantPrefix = "\nmodel\n", - assistantSuffix = "", - stopStrings = listOf(""), - ) -} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/model/ResourceState.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/model/ResourceState.kt index 865b6cc..c5a6068 100644 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/model/ResourceState.kt +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/model/ResourceState.kt @@ -29,7 +29,7 @@ sealed interface ResourceState { /** * An error occurred during loading. - * @param error The [LlamaError] describing what went wrong. + * @param result The [LlamaError] describing what went wrong. */ data class Failure(val error: LlamaError) : ResourceState } diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/model/SessionConfig.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/model/SessionConfig.kt deleted file mode 100644 index 1abbde4..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/model/SessionConfig.kt +++ /dev/null @@ -1,25 +0,0 @@ -package com.suhel.llamabro.sdk.model - -/** - * Configuration for creating a [com.suhel.llamabro.sdk.LlamaSession]. - * - * This class bundles together the context size, sampling behavior, and - * resource management strategies for a single inference session. - * - * @property contextSize The total number of tokens (system prompt + history + generation) - * that can be held in memory. Larger values consume more RAM. - * @property overflowStrategy How the session reacts when the [contextSize] is reached. - * Defaults to [OverflowStrategy.RollingWindow]. - * @property inferenceConfig Parameters controlling token selection (sampling) like - * temperature and penalties. - * @property decodeConfig Low-level tuning for the llama.cpp compute loop. - * @property seed The random seed for deterministic generation. Use -1 - * for a random seed on every run. - */ -data class SessionConfig( - val contextSize: Int = 2048, - val overflowStrategy: OverflowStrategy = OverflowStrategy.RollingWindow(), - val inferenceConfig: InferenceConfig = InferenceConfig(), - val decodeConfig: DecodeConfig = DecodeConfig(), - val seed: Int = -1, -) diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/model/TokenGenerationResult.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/model/TokenGenerationResult.kt deleted file mode 100644 index f303360..0000000 --- a/sdk/src/main/java/com/suhel/llamabro/sdk/model/TokenGenerationResult.kt +++ /dev/null @@ -1,6 +0,0 @@ -package com.suhel.llamabro.sdk.model - -data class TokenGenerationResult( - val token: String?, - val isComplete: Boolean, -) diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/models/ChatEvent.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/models/ChatEvent.kt new file mode 100644 index 0000000..1efd3b5 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/models/ChatEvent.kt @@ -0,0 +1,51 @@ +package com.suhel.llamabro.sdk.models + +import com.suhel.llamabro.sdk.toolcall.ToolCall +import com.suhel.llamabro.sdk.toolcall.ToolDefinition +import com.suhel.llamabro.sdk.toolcall.ToolResult +import kotlinx.serialization.Serializable + +@Serializable +sealed interface ChatEvent { + + @Serializable + data class SystemEvent( + val content: String, + val tools: List = emptyList(), + ) : ChatEvent + + @Serializable + data class UserEvent( + val content: String, + val think: Boolean, + ) : ChatEvent + + @Serializable + data class AssistantEvent( + val parts: List, + ) : ChatEvent { + val text: String + get() = parts.filterIsInstance().joinToString("") { it.content } + + val thinkingText: String + get() = parts.filterIsInstance().joinToString("") { it.content } + + val toolCalls: List + get() = parts.filterIsInstance().map { it.call } + + @Serializable + sealed interface Part { + @Serializable + data class TextPart(val content: String) : Part + @Serializable + data class ThinkingPart(val content: String) : Part + @Serializable + data class ToolCallPart(val call: ToolCall) : Part + } + } + + @Serializable + data class ToolResultEvent( + val result: ToolResult, + ) : ChatEvent +} diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/models/CompletionSnapshot.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/models/CompletionSnapshot.kt new file mode 100644 index 0000000..63c2bab --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/models/CompletionSnapshot.kt @@ -0,0 +1,15 @@ +package com.suhel.llamabro.sdk.models + +/** + * A snapshot of the model's completion at a point in time. + * + * Emitted progressively during inference. The final emission has [isComplete] = true. + */ +data class CompletionSnapshot( + val message: ChatEvent.AssistantEvent, + val isComplete: Boolean, + val isError: Boolean, + val error: String?, + /** Tokens generated per second. Only meaningful when [isComplete] is true. */ + val tokensPerSecond: Float = 0f, +) diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/toolcall/ToolCallDefinition.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/toolcall/ToolCallDefinition.kt new file mode 100644 index 0000000..33b5730 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/toolcall/ToolCallDefinition.kt @@ -0,0 +1,10 @@ +package com.suhel.llamabro.sdk.toolcall + +import com.suhel.llamabro.sdk.chat.pipeline.ToolCallMarker + +data class ToolCallDefinition( + val marker: ToolCallMarker, + val callParser: (String) -> ToolCall, + val callSerializer: (ToolCall) -> String, + val definitionFormatter: (List) -> String +) diff --git a/sdk/src/main/java/com/suhel/llamabro/sdk/toolcall/ToolModels.kt b/sdk/src/main/java/com/suhel/llamabro/sdk/toolcall/ToolModels.kt new file mode 100644 index 0000000..ab7cee8 --- /dev/null +++ b/sdk/src/main/java/com/suhel/llamabro/sdk/toolcall/ToolModels.kt @@ -0,0 +1,47 @@ +package com.suhel.llamabro.sdk.toolcall + +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject + +@Serializable +data class ToolDefinition( + val name: String, + val description: String, + val parameters: ToolParameters = ToolParameters() +) + +@Serializable +data class ToolParameters( + val properties: Map = emptyMap(), + val required: List = emptyList(), +) + +@Serializable +data class ToolParameter( + val type: Type, + val description: String? = null, + val properties: Map = emptyMap(), + val items: ToolParameter? = null, + val required: List = emptyList(), + val enum: List = emptyList(), + val nullable: Boolean = false, +) + +enum class Type { + STRING, NUMBER, INTEGER, BOOLEAN, OBJECT, ARRAY, +} + +@Serializable +data class ToolCall( + val id: String? = null, + val name: String, + val arguments: JsonObject, +) + +@Serializable +data class ToolResult( + val id: String? = null, + val name: String, + val result: JsonElement, +) diff --git a/sdk/src/test/java/com/suhel/llamabro/sdk/LlamaChatSessionImplTest.kt b/sdk/src/test/java/com/suhel/llamabro/sdk/LlamaChatSessionImplTest.kt deleted file mode 100644 index df060ac..0000000 --- a/sdk/src/test/java/com/suhel/llamabro/sdk/LlamaChatSessionImplTest.kt +++ /dev/null @@ -1,338 +0,0 @@ -package com.suhel.llamabro.sdk - -import com.suhel.llamabro.sdk.internal.LlamaChatSessionImpl -import com.suhel.llamabro.sdk.model.LlamaError -import com.suhel.llamabro.sdk.model.Message -import com.suhel.llamabro.sdk.model.ModelConfig -import com.suhel.llamabro.sdk.model.PromptFormat -import com.suhel.llamabro.sdk.model.PromptFormats -import com.suhel.llamabro.sdk.model.ResourceState -import com.suhel.llamabro.sdk.model.TokenGenerationResult -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.emptyFlow -import kotlinx.coroutines.flow.toList -import kotlinx.coroutines.test.runTest -import org.junit.Assert.assertEquals -import org.junit.Assert.assertFalse -import org.junit.Assert.assertNotNull -import org.junit.Assert.assertNull -import org.junit.Assert.assertTrue -import org.junit.Test - -/** - * Unit tests for [LlamaChatSessionImpl], focusing on token processing, - * thinking block extraction, and prompt formatting. - */ -class LlamaChatSessionImplTest { - - // ── Test double ───────────────────────────────────────────────────────── - - /** - * A fake [LlamaSession] that returns a pre-configured token sequence. - * Used to simulate model output without running native code. - */ - private class FakeSession( - private val tokens: List, - override val modelConfig: ModelConfig = ModelConfig("fake.gguf", PASSTHROUGH), - private val shouldThrow: LlamaError? = null - ) : LlamaSession { - private var index = 0 - val prompts = mutableListOf() - var cleared = false - var aborted = false - - override suspend fun setSystemPrompt(text: String, addSpecial: Boolean) { - prompts.add(text) - } - - override suspend fun ingestPrompt(prompt: String, addSpecial: Boolean) { - prompts.add(prompt) - } - - override suspend fun generate(): TokenGenerationResult { - shouldThrow?.let { throw it } - if (tokens.isEmpty()) return TokenGenerationResult(null, true) - - val token = tokens[index] - index++ - return TokenGenerationResult(token, index == tokens.size) - } - - override suspend fun clear() { - cleared = true - index = 0 - } - - override fun abort() { - aborted = true - } - - override fun close() {} - - override suspend fun createChatSession(systemPrompt: String): LlamaChatSession { - return LlamaChatSessionImpl(this, systemPrompt) - } - - override fun createChatSessionFlow(systemPrompt: String): Flow> { - return emptyFlow() - } - } - - companion object { - /** Transparent format — no wrapping, so tests focus on core behaviour. */ - private val PASSTHROUGH = PromptFormat( - systemPrefix = "", - systemSuffix = "", - userPrefix = "", - userSuffix = "", - assistantPrefix = "", - assistantSuffix = "", - ) - } - - // ── Performance Metrics ───────────────────────────────────────────────── - - @Test - fun `tokensPerSecond is non-zero for non-empty generation`() = runTest { - val session = LlamaChatSessionImpl(FakeSession(listOf("Hello", " world")), "") - val generations = session.completion("Hi").toList() - val last = generations.last() - - assertTrue("isComplete should be true", last.isComplete) - assertNotNull("tokensPerSecond should not be null", last.tokensPerSecond) - assertTrue( - "tokensPerSecond should be > 0, was ${last.tokensPerSecond}", - (last.tokensPerSecond ?: 0f) > 0f - ) - } - - @Test - fun `tokensPerSecond is zero for empty generation`() = runTest { - val session = LlamaChatSessionImpl(FakeSession(emptyList()), "") - val generations = session.completion("Hi").toList() - val last = generations.last() - - assertTrue(last.isComplete) - } - - // ── Content accumulation ──────────────────────────────────────────────── - - @Test - fun `content text accumulates across tokens`() = runTest { - val session = LlamaChatSessionImpl(FakeSession(listOf("Hello", " ", "world")), "") - val generations = session.completion("Hi").toList() - val last = generations.last() - - assertEquals("Hello world", last.contentText) - assertTrue(last.isComplete) - } - - @Test - fun `scan produces progressive snapshots`() = runTest { - val session = LlamaChatSessionImpl(FakeSession(listOf("A", "B")), "") - val generations = session.completion("Hi").toList() - - // Emissions: A, AB, metrics (isComplete=true) - assertTrue(generations.size >= 2) - assertEquals("A", generations[0].contentText) - assertEquals("AB", generations[1].contentText) - } - - // ── Thinking tag classification ───────────────────────────────────────── - - @Test - fun `thinking tags separate thinking from content`() = runTest { - val tokens = listOf("", "reasoning", "", "answer") - val session = LlamaChatSessionImpl(FakeSession(tokens), "") - val generations = session.completion("Hi").toList() - val last = generations.last() - - assertEquals("reasoning", last.thinkingText) - assertEquals("answer", last.contentText) - assertTrue(last.isComplete) - } - - @Test - fun `thinking tags split across token boundaries`() = runTest { - // "" split across two tokens: "" - val tokens = listOf("", "deep thought", "", "visible") - val session = LlamaChatSessionImpl(FakeSession(tokens), "") - val generations = session.completion("Hi").toList() - val last = generations.last() - - assertEquals("deep thought", last.thinkingText) - assertEquals("visible", last.contentText) - assertTrue(last.isComplete) - } - - @Test - fun `content before think tag is emitted as content`() = runTest { - val tokens = listOf("preamble", "", "thought", "", "answer") - val session = LlamaChatSessionImpl(FakeSession(tokens), "") - val generations = session.completion("Hi").toList() - val last = generations.last() - - assertEquals("thought", last.thinkingText) - assertEquals("preambleanswer", last.contentText) - } - - @Test - fun `no thinking tags means all content`() = runTest { - val tokens = listOf("just", " content") - val session = LlamaChatSessionImpl(FakeSession(tokens), "") - val generations = session.completion("Hi").toList() - val last = generations.last() - - assertNull(last.thinkingText) - assertEquals("just content", last.contentText) - } - - // ── Blank line / empty message fixes ─────────────────────────────────── - - @Test - fun `trailing newlines are trimmed in final output`() = runTest { - val session = LlamaChatSessionImpl(FakeSession(listOf("Hello", "\n")), "") - val generations = session.completion("Hi").toList() - val last = generations.last() - - assertEquals("Hello", last.contentText) - assertTrue(last.isComplete) - } - - @Test - fun `leading newlines after thinking block are trimmed`() = runTest { - val tokens = listOf("", "r", "", "\n\n", "answer") - val session = LlamaChatSessionImpl(FakeSession(tokens), "") - val generations = session.completion("Hi").toList() - val last = generations.last() - - assertEquals("answer", last.contentText) - assertEquals("r", last.thinkingText) - } - - @Test - fun `thinking-only output keeps contentText null`() = runTest { - val tokens = listOf("", "thought", "") - val session = LlamaChatSessionImpl(FakeSession(tokens), "") - val generations = session.completion("Hi").toList() - val last = generations.last() - - assertNull(last.contentText) - assertEquals("thought", last.thinkingText) - } - - @Test - fun `whitespace-only content results in null contentText`() = runTest { - val session = LlamaChatSessionImpl(FakeSession(listOf("\n", "\n")), "") - val generations = session.completion("Hi").toList() - val last = generations.last() - - assertNull(last.contentText) - assertTrue(last.isComplete) - } - - @Test - fun `empty generation keeps both fields null`() = runTest { - val session = LlamaChatSessionImpl(FakeSession(emptyList()), "") - val generations = session.completion("Hi").toList() - val last = generations.last() - - assertNull(last.contentText) - assertNull(last.thinkingText) - assertTrue(last.isComplete) - } - - // ── Prompt formatting and assistant turn boundaries ─────────────────── - - @Test - fun `chat sends formatted user message with assistant prefix`() = runTest { - val modelConfig = ModelConfig("fake.gguf", PromptFormats.ChatML) - val fake = FakeSession(listOf("Hello"), modelConfig) - val session = LlamaChatSessionImpl(fake, "") - session.completion("Hi").toList() - - // Only prompt: user turn + assistant turn opening (no Kotlin-side closing) - assertEquals( - "<|im_start|>user\nHi<|im_end|><|im_start|>assistant\n", - fake.prompts[0] - ) - assertEquals("No Kotlin-side turn closing should occur", 1, fake.prompts.size) - } - - @Test - fun `content is not affected by manually provided eog tokens`() = runTest { - // Since the C++ engine already consumes EOG, we simulate that here by - // NOT including it in the tokens returned by FakeSession, as it would - // have been filtered out by the real session.cpp. - val modelConfig = ModelConfig("fake.gguf", PromptFormats.ChatML) - val fake = FakeSession(listOf("Hello world"), modelConfig) - val session = LlamaChatSessionImpl(fake, "") - val last = session.completion("Hi").toList().last() - - assertEquals("Hello world", last.contentText) - } - - @Test - fun `empty prefix and suffix do not affect content`() = runTest { - val modelConfig = ModelConfig("fake.gguf", PromptFormats.Mistral) - val fake = FakeSession(listOf("Hello"), modelConfig) - val session = LlamaChatSessionImpl(fake, "") - val last = session.completion("Hi").toList().last() - - assertEquals("Hello", last.contentText) - } - - @Test - fun `loadHistory formats each message with the prompt template`() = runTest { - val modelConfig = ModelConfig("fake.gguf", PromptFormats.ChatML) - val fake = FakeSession(emptyList(), modelConfig) - val session = LlamaChatSessionImpl(fake, "") - - session.loadHistory( - listOf( - Message.User("hello"), - Message.Assistant("hi there"), - ) - ) - - assertEquals("<|im_start|>user\nhello<|im_end|>", fake.prompts[0]) - assertEquals("<|im_start|>assistant\nhi there<|im_end|>", fake.prompts[1]) - } - - // ── Reset and Error Handling ─────────────────────────────────────────── - - @Test - fun `reset delegates to session clear`() = runTest { - val fake = FakeSession(emptyList()) - val session = LlamaChatSessionImpl(fake, "") - session.reset() - assertTrue(fake.cleared) - } - - @Test - fun `cancelled error emits interrupted state`() = runTest { - val fake = FakeSession( - tokens = emptyList(), - shouldThrow = LlamaError.Cancelled() - ) - val session = LlamaChatSessionImpl(fake, "") - val generations = session.completion("Hi").toList() - - assertTrue(generations.last().isInterrupted) - assertTrue(generations.last().isComplete) - } - - @Test - fun `fatal errors are emitted as a completion with error field`() = runTest { - val fake = FakeSession( - tokens = emptyList(), - shouldThrow = LlamaError.DecodeFailed(1) - ) - val session = LlamaChatSessionImpl(fake, "") - val last = session.completion("Hi").toList().last() - - assertTrue(last.isComplete) - assertFalse(last.isInterrupted) - assertTrue(last.error is LlamaError.DecodeFailed) - } -} diff --git a/sdk/src/test/java/com/suhel/llamabro/sdk/chat/internal/LlamaChatSessionImplTest.kt b/sdk/src/test/java/com/suhel/llamabro/sdk/chat/internal/LlamaChatSessionImplTest.kt new file mode 100644 index 0000000..95b5810 --- /dev/null +++ b/sdk/src/test/java/com/suhel/llamabro/sdk/chat/internal/LlamaChatSessionImplTest.kt @@ -0,0 +1,228 @@ +package com.suhel.llamabro.sdk.chat.internal + +import com.suhel.llamabro.sdk.chat.LlamaChatSession +import com.suhel.llamabro.sdk.chat.pipeline.ThinkingMarker +import com.suhel.llamabro.sdk.config.ModelDefinition +import com.suhel.llamabro.sdk.config.ModelLoadConfig +import com.suhel.llamabro.sdk.engine.LlamaSession +import com.suhel.llamabro.sdk.engine.TokenGenerationResult +import com.suhel.llamabro.sdk.engine.TokenGenerationResultCode +import com.suhel.llamabro.sdk.format.PromptFormats +import com.suhel.llamabro.sdk.model.ResourceState +import com.suhel.llamabro.sdk.models.ChatEvent +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.test.runTest +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Test + +/** + * Unit tests for [LlamaChatSessionImpl], exercising the full lexing → semantic → snapshot + * pipeline against a fake [LlamaSession] that emits pre-scripted token sequences. + */ +class LlamaChatSessionImplTest { + + // ── Test double ───────────────────────────────────────────────────────── + + /** + * A deterministic fake [LlamaSession] that replays a fixed token list, + * enabling precise, reproducible pipeline tests without native code. + */ + private class FakeSession( + private val tokens: List, + override val modelDefinition: ModelDefinition = noFeaturesModel() + ) : LlamaSession { + + val addedPrompts = mutableListOf() + + override suspend fun setPrefixedPrompt(text: String) = Unit + override suspend fun addPrompt(prompt: String) { addedPrompts += prompt } + override suspend fun generate(): TokenGenerationResult = + error("Use generateFlow() in tests") + override suspend fun clear() = Unit + override fun abort() = Unit + override fun close() = Unit + + override fun generateFlow(): Flow = flow { + tokens.forEachIndexed { i, token -> + emit( + TokenGenerationResult( + token = token, + resultCode = TokenGenerationResultCode.OK, + isComplete = i == tokens.lastIndex, + ) + ) + } + // Edge case: emit completion signal on empty token list + if (tokens.isEmpty()) { + emit(TokenGenerationResult(null, TokenGenerationResultCode.OK, isComplete = true)) + } + } + + override suspend fun createChatSession(systemPrompt: String): LlamaChatSession = + LlamaChatSessionImpl(this, systemPrompt) + + override fun createChatSessionFlow(systemPrompt: String): Flow> = + flow { emit(ResourceState.Success(createChatSession(systemPrompt))) } + } + + companion object { + private fun noFeaturesModel() = ModelDefinition( + loadConfig = ModelLoadConfig(path = "fake.gguf"), + promptFormat = PromptFormats.CHAT_ML, + ) + + private fun thinkingModel() = ModelDefinition( + loadConfig = ModelLoadConfig(path = "fake.gguf"), + promptFormat = PromptFormats.CHAT_ML, + features = listOf(ThinkingMarker(open = "", close = "")) + ) + } + + // ── Snapshot accumulation ──────────────────────────────────────────────── + + @Test + fun `text tokens accumulate across intermediary snapshots`() = runTest { + val session = LlamaChatSessionImpl(FakeSession(listOf("Hello", " ", "world")), "") + val snapshots = session.completion(ChatEvent.UserEvent("Hi", think = false)).toList() + + val last = snapshots.last() + assertEquals("Hello world", last.message.text) + assertTrue(last.isComplete) + assertFalse(last.isError) + } + + @Test + fun `snapshots are emitted progressively — one per semantic chunk`() = runTest { + val session = LlamaChatSessionImpl(FakeSession(listOf("A", "B")), "") + val snapshots = session.completion(ChatEvent.UserEvent("Hi", think = false)).toList() + + // Each text chunk emits an intermediate snapshot, plus a final one + assertTrue("Expected at least 3 snapshots", snapshots.size >= 3) + assertEquals("A", snapshots[0].message.text) + assertEquals("AB", snapshots[1].message.text) + assertTrue(snapshots.last().isComplete) + } + + @Test + fun `empty token list produces a single complete snapshot with empty content`() = runTest { + val session = LlamaChatSessionImpl(FakeSession(emptyList()), "") + val snapshots = session.completion(ChatEvent.UserEvent("Hi", think = false)).toList() + + assertEquals(1, snapshots.size) + assertTrue(snapshots[0].isComplete) + assertEquals("", snapshots[0].message.text) + } + + // ── Thinking tag handling ──────────────────────────────────────────────── + + @Test + fun `thinking tags correctly partition text into thinking and content parts`() = runTest { + val fake = FakeSession( + tokens = listOf("", "reasoning", "", "answer"), + modelDefinition = thinkingModel() + ) + val session = LlamaChatSessionImpl(fake, "") + val last = session.completion(ChatEvent.UserEvent("Hi", think = false)).toList().last() + + assertEquals("reasoning", last.message.thinkingText) + assertEquals("answer", last.message.text) + } + + @Test + fun `thinking tag split across token boundaries is correctly assembled`() = runTest { + val fake = FakeSession( + tokens = listOf("", "deep thought", "", "visible"), + modelDefinition = thinkingModel() + ) + val session = LlamaChatSessionImpl(fake, "") + val last = session.completion(ChatEvent.UserEvent("Hi", think = false)).toList().last() + + assertEquals("deep thought", last.message.thinkingText) + assertEquals("visible", last.message.text) + } + + @Test + fun `text before thinking tag is classified as content`() = runTest { + val fake = FakeSession( + tokens = listOf("preamble", "", "thought", "", "answer"), + modelDefinition = thinkingModel() + ) + val session = LlamaChatSessionImpl(fake, "") + val last = session.completion(ChatEvent.UserEvent("Hi", think = false)).toList().last() + + assertEquals("thought", last.message.thinkingText) + assertEquals("preambleanswer", last.message.text) + } + + @Test + fun `thinking-only output has empty text`() = runTest { + val fake = FakeSession( + tokens = listOf("", "thought", ""), + modelDefinition = thinkingModel() + ) + val session = LlamaChatSessionImpl(fake, "") + val last = session.completion(ChatEvent.UserEvent("Hi", think = false)).toList().last() + + assertEquals("thought", last.message.thinkingText) + assertEquals("", last.message.text) + } + + // ── Prompt forwarding ──────────────────────────────────────────────────── + + @Test + fun `completion adds a formatted user prompt to the session`() = runTest { + val fake = FakeSession(listOf("ok")) + val session = LlamaChatSessionImpl(fake, "") + session.completion(ChatEvent.UserEvent("Hi", think = false)).toList() + + // ChatML user prompt + assistant prefix should be the first prompt added + assertTrue("Expected a user prompt to be added", fake.addedPrompts.isNotEmpty()) + assertTrue( + "Prompt should start with ChatML user prefix", + fake.addedPrompts[0].startsWith("<|im_start|>user\n") + ) + } + + // ── History replay ──────────────────────────────────────────────────────── + + @Test + fun `feedHistory adds formatted prompts for each history event`() = runTest { + val fake = FakeSession(emptyList()) + val session = LlamaChatSessionImpl(fake, "") + session.feedHistory( + listOf( + ChatEvent.UserEvent("hello", think = false), + ChatEvent.AssistantEvent(listOf(ChatEvent.AssistantEvent.Part.TextPart("hi there"))), + ) + ) + + assertEquals(2, fake.addedPrompts.size) + assertTrue(fake.addedPrompts[0].contains("hello")) + assertTrue(fake.addedPrompts[1].contains("hi there")) + } + + // ── Tokens per second ──────────────────────────────────────────────────── + + @Test + fun `tokensPerSecond is positive for non-empty generation`() = runTest { + val session = LlamaChatSessionImpl(FakeSession(listOf("Hello", " ", "world")), "") + val last = session.completion(ChatEvent.UserEvent("Hi", think = false)).toList().last() + + assertTrue(last.isComplete) + // In unit tests the wall clock is too fast to measure precisely, so just confirm non-negative + assertTrue(last.tokensPerSecond >= 0f) + } + + @Test + fun `tokensPerSecond is zero for empty generation`() = runTest { + val session = LlamaChatSessionImpl(FakeSession(emptyList()), "") + val last = session.completion(ChatEvent.UserEvent("Hi", think = false)).toList().last() + + assertTrue(last.isComplete) + assertEquals(0f, last.tokensPerSecond) + } +} diff --git a/sdk/src/test/java/com/suhel/llamabro/sdk/chat/pipeline/AllocationOptimizedScannerTest.kt b/sdk/src/test/java/com/suhel/llamabro/sdk/chat/pipeline/AllocationOptimizedScannerTest.kt new file mode 100644 index 0000000..7cab1b2 --- /dev/null +++ b/sdk/src/test/java/com/suhel/llamabro/sdk/chat/pipeline/AllocationOptimizedScannerTest.kt @@ -0,0 +1,163 @@ +package com.suhel.llamabro.sdk.chat.pipeline + +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Test + +/** + * Unit tests for [AllocationOptimizedScanner]. + * + * Covers pure-text pass-through, complete tag recognition, tag content extraction, + * and robustness against tags split across multiple token boundaries. + */ +class AllocationOptimizedScannerTest { + + private val thinkMarker = ThinkingMarker(open = "", close = "") + private val toolMarker = ToolCallMarker(open = "", close = "") + + // ── Pure text (no markers configured) ────────────────────────────────── + + @Test + fun `plain text with no markers emits single Text event`() { + val scanner = AllocationOptimizedScanner(emptyList()) + val events = scanner.feed("Hello world") + assertEquals(listOf(LexerEvent.Text("Hello world")), events) + } + + @Test + fun `empty token produces no events`() { + val scanner = AllocationOptimizedScanner(emptyList()) + assertTrue(scanner.feed("").isEmpty()) + } + + @Test + fun `null token on empty buffer produces no events`() { + val scanner = AllocationOptimizedScanner(listOf(thinkMarker)) + assertTrue(scanner.feed(null).isEmpty()) + } + + // ── Full tag in a single token ────────────────────────────────────────── + + @Test + fun `full opening tag is emitted as TagOpened with no preceding text`() { + val scanner = AllocationOptimizedScanner(listOf(thinkMarker)) + val events = scanner.feed("") + assertEquals(listOf(LexerEvent.TagOpened(thinkMarker)), events) + } + + @Test + fun `text before opening tag is flushed before TagOpened`() { + val scanner = AllocationOptimizedScanner(listOf(thinkMarker)) + val events = scanner.feed("preamble") + assertEquals( + listOf( + LexerEvent.Text("preamble"), + LexerEvent.TagOpened(thinkMarker), + ), + events + ) + } + + @Test + fun `content inside tag is emitted as TagContent`() { + val scanner = AllocationOptimizedScanner(listOf(thinkMarker)) + scanner.feed("") + val events = scanner.feed("deep thought") + assertEquals(listOf(LexerEvent.TagContent(thinkMarker, "deep thought")), events) + } + + @Test + fun `complete round-trip emits opened, content, closed`() { + val scanner = AllocationOptimizedScanner(listOf(thinkMarker)) + val all = scanner.feed("") + scanner.feed("reasoning") + scanner.feed("") + assertEquals( + listOf( + LexerEvent.TagOpened(thinkMarker), + LexerEvent.TagContent(thinkMarker, "reasoning"), + LexerEvent.TagClosed(thinkMarker), + ), + all + ) + } + + @Test + fun `text after closing tag is emitted as Text`() { + val scanner = AllocationOptimizedScanner(listOf(thinkMarker)) + val all = scanner.feed("r") + scanner.feed("answer") + assertEquals(LexerEvent.Text("answer"), all.last()) + } + + // ── Tags split across token boundaries ──────────────────────────────────── + + @Test + fun `opening tag split across two tokens is correctly recognized`() { + val scanner = AllocationOptimizedScanner(listOf(thinkMarker)) + val a = scanner.feed("") // completes the tag + assertTrue("No events on partial", a.isEmpty()) + assertEquals(listOf(LexerEvent.TagOpened(thinkMarker)), b) + } + + @Test + fun `closing tag split across two tokens is correctly recognized`() { + val scanner = AllocationOptimizedScanner(listOf(thinkMarker)) + scanner.feed("") + scanner.feed("content") + val a = scanner.feed("") // completes it + // 'a' may include a TagContent for the partial prefix of the closing tag held back — it should be empty + assertTrue("No events on partial close", a.isEmpty()) + assertTrue("Closed event in second feed", b.any { it is LexerEvent.TagClosed }) + } + + @Test + fun `tag split into single character tokens is correctly assembled`() { + val scanner = AllocationOptimizedScanner(listOf(thinkMarker)) + // Feed "" one char at a time + val events = "".map { scanner.feed(it.toString()) }.flatten() + assertEquals(listOf(LexerEvent.TagOpened(thinkMarker)), events) + } + + // ── Multiple markers ─────────────────────────────────────────────────── + + @Test + fun `think and tool markers are both recognized independently`() { + val markers = listOf(thinkMarker, toolMarker) + val scanner = AllocationOptimizedScanner(markers) + val events = scanner.feed("rfn") + + assertTrue(events.any { it is LexerEvent.TagOpened && it.marker == thinkMarker }) + assertTrue(events.any { it is LexerEvent.TagOpened && it.marker == toolMarker }) + assertTrue(events.any { it is LexerEvent.TagClosed && it.marker == thinkMarker }) + assertTrue(events.any { it is LexerEvent.TagClosed && it.marker == toolMarker }) + } + + // ── Flush on stream end ─────────────────────────────────────────────────── + + @Test + fun `null token with no partial tag in buffer produces no events`() { + val scanner = AllocationOptimizedScanner(listOf(thinkMarker)) + // Feed a complete text token — it is immediately flushed + val textEvents = scanner.feed("Hello") + assertEquals(listOf(LexerEvent.Text("Hello")), textEvents) + + // Feeding null after a fully-flushed buffer should produce nothing + val endEvents = scanner.feed(null) + assertTrue("No events expected after null on empty buffer", endEvents.isEmpty()) + } + + @Test + fun `null token with partial tag in buffer does not speculatively flush it as text`() { + // The scanner is conservative: if a buffer tail could still be the start of a registered + // tag, it will NOT emit it speculatively. The LLM stream is expected to disambiguate it + // by emitting more tokens. Null does not break this invariant. + val scanner = AllocationOptimizedScanner(listOf(thinkMarker)) + // "" — scanner holds it + val partial = scanner.feed("system\nYou are helpful.<|im_end|>\n", + formatter.formatTurn(ChatEvent.SystemEvent("You are helpful.")) + ) + } + + // ── User event ────────────────────────────────────────────────────────── + + @Test + fun `user event wraps content and emits assistant prefix when configured`() { + val formatter = PromptFormatter(modelWith(PromptFormats.CHAT_ML)) + val result = formatter.formatTurn(ChatEvent.UserEvent("Hello", think = false)) + assertEquals( + "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n", + result + ) + } + + @Test + fun `user event with think=true appends thinking marker open tag`() { + val model = modelWithThinking(PromptFormats.CHAT_ML) + val formatter = PromptFormatter(model) + val result = formatter.formatTurn(ChatEvent.UserEvent("Solve this", think = true)) + // Should end with assistantPrefix + thinking open tag + val expected = "<|im_start|>user\nSolve this<|im_end|>\n<|im_start|>assistant\n\n" + assertEquals(expected, result) + } + + @Test + fun `user event with think=false does not inject thinking marker`() { + val model = modelWithThinking(PromptFormats.CHAT_ML) + val formatter = PromptFormatter(model) + val result = formatter.formatTurn(ChatEvent.UserEvent("Hello", think = false)) + assert(!result.contains("")) { + "Expected no thinking tag, but got: $result" + } + } + + // ── Assistant event ────────────────────────────────────────────────────── + + @Test + fun `assistant event with single text part formats correctly in ChatML`() { + val formatter = PromptFormatter(modelWith(PromptFormats.CHAT_ML)) + val event = ChatEvent.AssistantEvent(parts = listOf(ChatEvent.AssistantEvent.Part.TextPart("Hi there"))) + val result = formatter.formatTurn(event) + assertEquals("Hi there<|im_end|>\n", result) + } + + @Test + fun `assistant event with no text parts emits only endOfTurn`() { + val formatter = PromptFormatter(modelWith(PromptFormats.CHAT_ML)) + val event = ChatEvent.AssistantEvent(parts = emptyList()) + val result = formatter.formatTurn(event) + assertEquals("<|im_end|>\n", result) + } + + // ── Llama3 format ──────────────────────────────────────────────────────── + + @Test + fun `Llama3 user event wraps with correct header tokens`() { + val formatter = PromptFormatter(modelWith(PromptFormats.LLAMA_3)) + val result = formatter.formatTurn(ChatEvent.UserEvent("Hello", think = false)) + assertEquals( + "<|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", + result + ) + } + + // ── Nemotron format ────────────────────────────────────────────────────── + + @Test + fun `Nemotron system event uses extra_id_0 sentinel`() { + val formatter = PromptFormatter(modelWith(PromptFormats.NEMOTRON)) + val result = formatter.formatTurn(ChatEvent.SystemEvent("You are helpful.")) + assertEquals("System\nYou are helpful.\n", result) + } + + @Test + fun `Nemotron user event uses extra_id_1 sentinel`() { + val formatter = PromptFormatter(modelWith(PromptFormats.NEMOTRON)) + val result = formatter.formatTurn(ChatEvent.UserEvent("Hello", think = false)) + assertEquals("User\nHello\nAssistant\n", result) + } +} diff --git a/sdk/src/test/java/com/suhel/llamabro/sdk/internal/NativeErrorMapperTest.kt b/sdk/src/test/java/com/suhel/llamabro/sdk/internal/NativeErrorMapperTest.kt deleted file mode 100644 index 9a3096f..0000000 --- a/sdk/src/test/java/com/suhel/llamabro/sdk/internal/NativeErrorMapperTest.kt +++ /dev/null @@ -1,84 +0,0 @@ -package com.suhel.llamabro.sdk.internal - -import com.suhel.llamabro.sdk.model.LlamaError -import org.junit.Assert.assertEquals -import org.junit.Assert.assertNull -import org.junit.Assert.assertTrue -import org.junit.Test - -class NativeErrorMapperTest { - - @Test - fun `code 1 maps to ModelNotFound`() { - val error = mapNativeError(RuntimeException("1:/data/model.gguf")) - assertTrue(error is LlamaError.ModelNotFound) - assertEquals("/data/model.gguf", (error as LlamaError.ModelNotFound).path) - } - - @Test - fun `code 2 maps to ModelLoadFailed`() { - val error = mapNativeError(RuntimeException("2:/data/model.gguf")) - assertTrue(error is LlamaError.ModelLoadFailed) - assertEquals("/data/model.gguf", (error as LlamaError.ModelLoadFailed).path) - } - - @Test - fun `code 3 maps to BackendLoadFailed`() { - val error = mapNativeError(RuntimeException("3:ggml-cpu")) - assertTrue(error is LlamaError.BackendLoadFailed) - assertEquals("ggml-cpu", (error as LlamaError.BackendLoadFailed).backendName) - } - - @Test - fun `code 10 maps to ContextInitFailed`() { - val error = mapNativeError(RuntimeException("10:")) - assertTrue(error is LlamaError.ContextInitFailed) - } - - @Test - fun `code 11 maps to ContextOverflow`() { - val error = mapNativeError(RuntimeException("11:")) - assertTrue(error is LlamaError.ContextOverflow) - } - - @Test - fun `code 12 maps to DecodeFailed with decode error code`() { - val error = mapNativeError(RuntimeException("12:42")) - assertTrue(error is LlamaError.DecodeFailed) - assertEquals(42, (error as LlamaError.DecodeFailed).code) - } - - @Test - fun `code 12 with non-numeric detail falls back to -1`() { - val error = mapNativeError(RuntimeException("12:not_a_number")) - assertTrue(error is LlamaError.DecodeFailed) - assertEquals(-1, (error as LlamaError.DecodeFailed).code) - } - - @Test - fun `unknown code maps to NativeException`() { - val error = mapNativeError(RuntimeException("999:something went wrong")) - assertTrue(error is LlamaError.NativeException) - assertEquals("something went wrong", (error as LlamaError.NativeException).nativeMessage) - } - - @Test - fun `null message maps to NativeException with Unknown`() { - val error = mapNativeError(RuntimeException(null as String?)) - assertTrue(error is LlamaError.NativeException) - assertEquals("Native error: Unknown native error", error.message) - } - - @Test - fun `message without colon maps to NativeException`() { - val error = mapNativeError(RuntimeException("no colon here")) - assertTrue(error is LlamaError.NativeException) - assertEquals("no colon here", (error as LlamaError.NativeException).nativeMessage) - } - - @Test - fun `empty message maps to NativeException`() { - val error = mapNativeError(RuntimeException("")) - assertTrue(error is LlamaError.NativeException) - } -} diff --git a/sdk/src/test/java/com/suhel/llamabro/sdk/internal/TokenStreamParserTest.kt b/sdk/src/test/java/com/suhel/llamabro/sdk/internal/TokenStreamParserTest.kt deleted file mode 100644 index 6cdcd7f..0000000 --- a/sdk/src/test/java/com/suhel/llamabro/sdk/internal/TokenStreamParserTest.kt +++ /dev/null @@ -1,253 +0,0 @@ -package com.suhel.llamabro.sdk.internal - -import org.junit.Assert.assertEquals -import org.junit.Assert.assertFalse -import org.junit.Assert.assertNull -import org.junit.Assert.assertTrue -import org.junit.Test - -class TokenStreamParserTest { - - // ── Helpers ────────────────────────────────────────────────────────────── - - private fun TokenStreamParser.feed(vararg tokens: String): Pair { - val content = StringBuilder() - val thinking = StringBuilder() - for (token in tokens) process(token, content, thinking) - flush(content, thinking) - return content.ifEmpty { null }?.toString() to thinking.ifEmpty { null }?.toString() - } - - // ── Basic routing ──────────────────────────────────────────────────────── - - @Test - fun `empty stream produces nulls`() { - val parser = TokenStreamParser() - val (content, thinking) = parser.feed() - assertNull(content) - assertNull(thinking) - } - - @Test - fun `pure content with no tags`() { - val parser = TokenStreamParser() - val (content, thinking) = parser.feed("Hello", " world") - assertEquals("Hello world", content) - assertNull(thinking) - } - - @Test - fun `pure thinking block`() { - val parser = TokenStreamParser() - val (content, thinking) = parser.feed("", "reasoning", "") - assertNull(content) - assertEquals("reasoning", thinking) - } - - @Test - fun `thinking then content`() { - val parser = TokenStreamParser() - val (content, thinking) = parser.feed("", "thought", "", "answer") - assertEquals("answer", content) - assertEquals("thought", thinking) - } - - @Test - fun `content before thinking block`() { - val parser = TokenStreamParser() - val (content, thinking) = parser.feed("pre", "", "mid", "", "post") - assertEquals("prepost", content) - assertEquals("mid", thinking) - } - - // ── Tag splitting across token boundaries ──────────────────────────────── - - @Test - fun `open tag split across two tokens`() { - val parser = TokenStreamParser() - val (content, thinking) = parser.feed("", "thought", "", "answer") - assertEquals("answer", content) - assertEquals("thought", thinking) - } - - @Test - fun `close tag split across two tokens`() { - val parser = TokenStreamParser() - val (content, thinking) = parser.feed("", "thought", "", "answer") - assertEquals("answer", content) - assertEquals("thought", thinking) - } - - @Test - fun `both tags split one character at a time`() { - val parser = TokenStreamParser() - val tokens = "thoughtanswer".map { it.toString() }.toTypedArray() - val (content, thinking) = parser.feed(*tokens) - assertEquals("answer", content) - assertEquals("thought", thinking) - } - - // ── Edge cases ─────────────────────────────────────────────────────────── - - @Test - fun `empty thinking block`() { - val parser = TokenStreamParser() - val (content, thinking) = parser.feed("", "", "answer") - assertEquals("answer", content) - assertNull(thinking) - } - - @Test - fun `multiple thinking blocks are concatenated`() { - val parser = TokenStreamParser() - val (content, thinking) = parser.feed( - "", "first", "", - "between", - "", "second", "", - "end" - ) - assertEquals("betweenend", content) - assertEquals("firstsecond", thinking) - } - - @Test - fun `stream ends mid open tag — flushed to content`() { - val parser = TokenStreamParser() - val content = StringBuilder() - val thinking = StringBuilder() - parser.process("thought", "answer") - assertEquals("answer", content) - assertEquals("deep thought", thinking) - } - - @Test - fun `reset clears prior state`() { - val parser = TokenStreamParser() - parser.feed("", "old") - assertTrue(parser.isThinking) - - parser.reset() - assertFalse(parser.isThinking) - assertFalse(parser.isStopped) - - val (content, _) = parser.feed("fresh") - assertEquals("fresh", content) - } - - // ── Stop strings ───────────────────────────────────────────────────────── - - @Test - fun `stop string halts output and is not emitted`() { - val parser = TokenStreamParser(stopStrings = listOf("[STOP]")) - val (content, _) = parser.feed("hello ", "[STOP]", "ignored") - assertEquals("hello ", content) - assertTrue(parser.isStopped) - } - - @Test - fun `stop string split across tokens`() { - val parser = TokenStreamParser(stopStrings = listOf("[STOP]")) - val (content, _) = parser.feed("hello ", "[ST", "OP]", "ignored") - assertEquals("hello ", content) - assertTrue(parser.isStopped) - } - - @Test - fun `content before stop string is preserved`() { - val parser = TokenStreamParser(stopStrings = listOf("<|im_start|>")) - val (content, _) = parser.feed("answer text", "<|im_start|>", "user\nnext turn") - assertEquals("answer text", content) - assertTrue(parser.isStopped) - } - - @Test - fun `stop string before think tag — stop wins`() { - val parser = TokenStreamParser(stopStrings = listOf("[STOP]")) - val (content, thinking) = parser.feed("before [STOP]ignored") - assertEquals("before ", content) - assertNull(thinking) - assertTrue(parser.isStopped) - } - - @Test - fun `think tag before stop string — think routes normally`() { - val parser = TokenStreamParser(stopStrings = listOf("[STOP]")) - val (content, thinking) = parser.feed("thoughtanswer[STOP]extra") - assertEquals("answer", content) - assertEquals("thought", thinking) - assertTrue(parser.isStopped) - } - - @Test - fun `isStopped makes subsequent process calls no-ops`() { - val parser = TokenStreamParser(stopStrings = listOf("[STOP]")) - val content = StringBuilder() - val thinking = StringBuilder() - parser.process("[STOP]", content, thinking) - parser.process("should not appear", content, thinking) - assertEquals("", content.toString()) - assertTrue(parser.isStopped) - } - - @Test - fun `reset clears isStopped`() { - val parser = TokenStreamParser(stopStrings = listOf("[STOP]")) - parser.feed("[STOP]") - assertTrue(parser.isStopped) - parser.reset() - assertFalse(parser.isStopped) - val (content, _) = parser.feed("fresh") - assertEquals("fresh", content) - } - - @Test - fun `multiple stop strings — earliest wins`() { - val parser = TokenStreamParser(stopStrings = listOf("[B]", "[A]")) - val (content, _) = parser.feed("text[A]extra[B]more") - assertEquals("text", content) - assertTrue(parser.isStopped) - } - - @Test - fun `no stop strings — parser behaves identically to original`() { - val parser = TokenStreamParser(stopStrings = emptyList()) - val (content, thinking) = parser.feed("", "t", "", "c") - assertEquals("c", content) - assertEquals("t", thinking) - assertFalse(parser.isStopped) - } - - // ── Custom think tags ──────────────────────────────────────────────────── - - @Test - fun `custom think tags are respected`() { - val parser = TokenStreamParser(thinkingStart = "", thinkingEnd = "") - val (content, thinking) = parser.feed("", "thought", "", "answer") - assertEquals("answer", content) - assertEquals("thought", thinking) - } -} diff --git a/sdk/src/test/java/com/suhel/llamabro/sdk/util/PrompterTest.kt b/sdk/src/test/java/com/suhel/llamabro/sdk/util/PrompterTest.kt deleted file mode 100644 index 8138510..0000000 --- a/sdk/src/test/java/com/suhel/llamabro/sdk/util/PrompterTest.kt +++ /dev/null @@ -1,173 +0,0 @@ -package com.suhel.llamabro.sdk.util - -import com.suhel.llamabro.sdk.internal.Prompter -import com.suhel.llamabro.sdk.model.Message -import com.suhel.llamabro.sdk.model.PromptFormat -import com.suhel.llamabro.sdk.model.PromptFormats -import org.junit.Assert.assertEquals -import org.junit.Test - -class PrompterTest { - - // ── ChatML ────────────────────────────────────────────────────────────── - - @Test - fun `ChatML wraps user message correctly`() { - val fmt = Prompter(PromptFormats.ChatML) - assertEquals( - "<|im_start|>user\nHello<|im_end|>", - fmt.format(Message.User("Hello")) - ) - } - - @Test - fun `ChatML wraps assistant message correctly`() { - val fmt = Prompter(PromptFormats.ChatML) - assertEquals( - "<|im_start|>assistant\nHi there<|im_end|>", - fmt.format(Message.Assistant("Hi there")) - ) - } - - @Test - fun `ChatML turn lifecycle produces symmetric open and close`() { - val fmt = Prompter(PromptFormats.ChatML) - assertEquals("<|im_start|>assistant\n", fmt.assistantStart()) - assertEquals("<|im_end|>", fmt.assistantEnd()) - // assistantStart() + content + assistantEnd() == assistant(content) - val content = "Hello world" - assertEquals( - fmt.assistant(content), - fmt.assistantStart() + content + fmt.assistantEnd() - ) - } - - @Test - fun `ChatML shouldAddSpecial is true when BOS is null`() { - val fmt = Prompter(PromptFormats.ChatML) - assertEquals(true, fmt.shouldAddSpecial()) - } - - // ── Llama3 ────────────────────────────────────────────────────────────── - - @Test - fun `Llama3 includes BOS in initialization`() { - val fmt = Prompter(PromptFormats.Llama3) - val prompt = fmt.system("Be helpful") - assertEquals( - "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nBe helpful<|eot_id|>", - prompt - ) - } - - @Test - fun `Llama3 wraps user message correctly`() { - val fmt = Prompter(PromptFormats.Llama3) - assertEquals( - "<|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|>", - fmt.format(Message.User("Hello")) - ) - } - - @Test - fun `Llama3 turn lifecycle produces symmetric open and close`() { - val fmt = Prompter(PromptFormats.Llama3) - assertEquals("<|start_header_id|>assistant<|end_header_id|>\n\n", fmt.assistantStart()) - assertEquals("<|eot_id|>", fmt.assistantEnd()) - val content = "Hello" - assertEquals( - fmt.assistant(content), - fmt.assistantStart() + content + fmt.assistantEnd() - ) - } - - @Test - fun `Llama3 shouldAddSpecial is false when BOS is provided`() { - val fmt = Prompter(PromptFormats.Llama3) - assertEquals(false, fmt.shouldAddSpecial()) - } - - // ── Mistral ───────────────────────────────────────────────────────────── - - @Test - fun `Mistral includes BOS and EOS`() { - val fmt = Prompter(PromptFormats.Mistral) - assertEquals("", fmt.bos()) - assertEquals("", fmt.eos()) - - assertEquals( - "Response", - fmt.format(Message.Assistant("Response")) - ) - } - - @Test - fun `Mistral turn lifecycle with empty prefix and suffix`() { - val fmt = Prompter(PromptFormats.Mistral) - assertEquals("", fmt.assistantStart()) - assertEquals("", fmt.assistantEnd()) - val content = "Hello" - assertEquals( - fmt.assistant(content), - fmt.assistantStart() + content + fmt.assistantEnd() - ) - } - - // ── Gemma3 ────────────────────────────────────────────────────────────── - - @Test - fun `Gemma3 includes BOS`() { - val fmt = Prompter(PromptFormats.Gemma3) - assertEquals("", fmt.bos()) - } - - @Test - fun `Gemma3 turn lifecycle produces symmetric open and close`() { - val fmt = Prompter(PromptFormats.Gemma3) - assertEquals("\nmodel\n", fmt.assistantStart()) - assertEquals("", fmt.assistantEnd()) - val content = "Hello" - assertEquals( - fmt.assistant(content), - fmt.assistantStart() + content + fmt.assistantEnd() - ) - } - - // ── Edge cases ────────────────────────────────────────────────────────── - - @Test - fun `custom prompt format with BOS and EOS`() { - val custom = PromptFormat( - bos = "[BOS]", - eos = "[EOS]", - systemPrefix = "[SYS]", systemSuffix = "[/SYS]", - userPrefix = "[U]", userSuffix = "[/U]", - assistantPrefix = "[A]", assistantSuffix = "[/A]", - ) - val fmt = Prompter(custom) - assertEquals("[BOS]", fmt.bos()) - assertEquals("[U]hi[/U]", fmt.format(Message.User("hi"))) - assertEquals("[A]ok[/A][EOS]", fmt.format(Message.Assistant("ok"))) - } - - @Test - fun `shouldAddSpecial depends only on bos, not eos`() { - val eosOnly = PromptFormat( - bos = null, eos = "[EOS]", - systemPrefix = "", systemSuffix = "", - userPrefix = "", userSuffix = "", - assistantPrefix = "", assistantSuffix = "", - ) - // BOS is null, so tokenizer should add its native BOS regardless of eos - assertEquals(true, Prompter(eosOnly).shouldAddSpecial()) - } - - @Test - fun `assistant message with thinking block`() { - val fmt = Prompter(PromptFormats.ChatML) - assertEquals( - "<|im_start|>assistant\n\nreasoning\nanswer<|im_end|>", - fmt.assistant("answer", thinking = "reasoning") - ) - } -}