Skip to content

feat: dual model support for cost optimization #279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion packages/opencode/src/config/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ export namespace Config {
"Model to use in the format of provider/model, eg anthropic/claude-2",
)
.optional(),
lightweight_model: z
.string()
.describe("Lightweight model to use for tasks like window title generation")
.optional(),
provider: z
.record(
ModelsDev.Provider.partial().extend({
Expand Down Expand Up @@ -194,7 +198,7 @@ export namespace Config {
)
await fs.unlink(path.join(Global.Path.config, "config"))
})
.catch(() => {})
.catch(() => { })

return result
})
Expand Down
35 changes: 35 additions & 0 deletions packages/opencode/src/provider/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,41 @@ export namespace Provider {
}
}

export async function getLightweightModel(providerID: string): Promise<{ info: ModelsDev.Model; language: LanguageModel } | null> {
const cfg = await Config.get()

// Check user override
if (cfg.lightweight_model) {
try {
// Parse the lightweight model to get its provider
const { providerID: lightweightProviderID, modelID } = parseModel(cfg.lightweight_model)
return await getModel(lightweightProviderID, modelID)
} catch (e) {
log.warn("Failed to get configured lightweight model", { lightweight_model: cfg.lightweight_model, error: e })
}
}

const providers = await list()
const provider = providers[providerID]
if (!provider) return null

// Select cheapest model whose cost.output <= 4
let selected: { info: ModelsDev.Model; language: LanguageModel } | null = null
for (const model of Object.values(provider.info.models)) {
if (model.cost.output <= 4) {
try {
const m = await getModel(providerID, model.id)
if (!selected || m.info.cost.output < selected.info.cost.output) {
selected = m
}
} catch {
// ignore errors and continue searching
}
}
}
return selected
}

const TOOLS = [
BashTool,
EditTool,
Expand Down
2 changes: 1 addition & 1 deletion packages/opencode/src/provider/transform.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { CoreMessage, LanguageModelV1Prompt } from "ai"
import type { LanguageModelV1Prompt } from "ai"
import { unique } from "remeda"

export namespace ProviderTransform {
Expand Down
110 changes: 83 additions & 27 deletions packages/tui/internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"path/filepath"
"sort"
"strings"

"log/slog"

Expand All @@ -20,24 +19,29 @@ import (
var RootPath string

type App struct {
Info client.AppInfo
Version string
StatePath string
Config *client.ConfigInfo
Client *client.ClientWithResponses
State *config.State
Provider *client.ProviderInfo
Model *client.ModelInfo
Session *client.SessionInfo
Messages []client.MessageInfo
Commands commands.CommandRegistry
Info client.AppInfo
Version string
StatePath string
Config *client.ConfigInfo
Client *client.ClientWithResponses
State *config.State
MainProvider *client.ProviderInfo
MainModel *client.ModelInfo
LightProvider *client.ProviderInfo
LightModel *client.ModelInfo
Session *client.SessionInfo
Messages []client.MessageInfo
Commands commands.CommandRegistry
}

type SessionSelectedMsg = *client.SessionInfo
type ModelSelectedMsg struct {
Provider client.ProviderInfo
Model client.ModelInfo
MainProvider client.ProviderInfo
MainModel client.ModelInfo
LightweightProvider client.ProviderInfo
LightweightModel client.ModelInfo
}

type SessionClearedMsg struct{}
type CompactSessionMsg struct{}
type SendMsg struct {
Expand Down Expand Up @@ -83,9 +87,10 @@ func New(
appState.Theme = *configInfo.Theme
}
if configInfo.Model != nil {
splits := strings.Split(*configInfo.Model, "/")
appState.Provider = splits[0]
appState.Model = strings.Join(splits[1:], "/")
appState.MainProvider, appState.MainModel = util.ParseModel(*configInfo.Model)
}
if configInfo.LightweightModel != nil {
appState.LightProvider, appState.LightModel = util.ParseModel(*configInfo.LightweightModel)
}

// Load themes from all directories
Expand Down Expand Up @@ -158,11 +163,11 @@ func (a *App) InitializeProvider() tea.Cmd {
var currentProvider *client.ProviderInfo
var currentModel *client.ModelInfo
for _, provider := range providers {
if provider.Id == a.State.Provider {
if provider.Id == a.State.MainProvider {
currentProvider = &provider

for _, model := range provider.Models {
if model.Id == a.State.Model {
if model.Id == a.State.MainModel {
currentModel = &model
}
}
Expand All @@ -173,10 +178,40 @@ func (a *App) InitializeProvider() tea.Cmd {
currentModel = defaultModel
}

// Initialize lightweight model based on config or defaults
lightProvider := currentProvider
lightModel := currentModel

if a.State.LightProvider != "" && a.State.LightModel != "" {
lightProviderID, lightModelID := a.State.LightProvider, a.State.LightModel
// Find provider/model
for _, provider := range providers {
if provider.Id == lightProviderID {
lightProvider = &provider
for _, model := range provider.Models {
if model.Id == lightModelID {
lightModel = &model
break
}
}
break
}
}
} else {
// Try to find a default lightweight model for the provider
lightModel = getDefaultLightweightModel(*currentProvider)
if lightModel == nil {
// Fall back to the main model
lightModel = currentModel
}
}

// TODO: handle no provider or model setup, yet
return ModelSelectedMsg{
Provider: *currentProvider,
Model: *currentModel,
MainProvider: *currentProvider,
MainModel: *currentModel,
LightweightProvider: *lightProvider,
LightweightModel: *lightModel,
}
}
}
Expand All @@ -193,6 +228,20 @@ func getDefaultModel(response *client.PostProviderListResponse, provider client.
return nil
}

func getDefaultLightweightModel(provider client.ProviderInfo) *client.ModelInfo {
// Select the cheapest model whose Cost.Output <= 4
var selected *client.ModelInfo
for _, model := range provider.Models {
if model.Cost.Output <= 4 {
if selected == nil || model.Cost.Output < selected.Cost.Output {
tmp := model // create copy to take address of loop variable safely
selected = &tmp
}
}
}
return selected
}

type Attachment struct {
FilePath string
FileName string
Expand Down Expand Up @@ -231,8 +280,8 @@ func (a *App) InitializeProject(ctx context.Context) tea.Cmd {
go func() {
response, err := a.Client.PostSessionInitialize(ctx, client.PostSessionInitializeJSONRequestBody{
SessionID: a.Session.Id,
ProviderID: a.Provider.Id,
ModelID: a.Model.Id,
ProviderID: a.MainProvider.Id,
ModelID: a.MainModel.Id,
})
if err != nil {
slog.Error("Failed to initialize project", "error", err)
Expand All @@ -248,10 +297,17 @@ func (a *App) InitializeProject(ctx context.Context) tea.Cmd {
}

func (a *App) CompactSession(ctx context.Context) tea.Cmd {
providerID := a.MainProvider.Id
modelID := a.MainModel.Id
if a.LightProvider != nil && a.LightModel != nil {
providerID = a.LightProvider.Id
modelID = a.LightModel.Id
}

response, err := a.Client.PostSessionSummarizeWithResponse(ctx, client.PostSessionSummarizeJSONRequestBody{
SessionID: a.Session.Id,
ProviderID: a.Provider.Id,
ModelID: a.Model.Id,
ProviderID: providerID,
ModelID: modelID,
})
if err != nil {
slog.Error("Failed to compact session", "error", err)
Expand Down Expand Up @@ -315,8 +371,8 @@ func (a *App) SendChatMessage(ctx context.Context, text string, attachments []At
response, err := a.Client.PostSessionChat(ctx, client.PostSessionChatJSONRequestBody{
SessionID: a.Session.Id,
Parts: parts,
ProviderID: a.Provider.Id,
ModelID: a.Model.Id,
ProviderID: a.MainProvider.Id,
ModelID: a.MainModel.Id,
})
if err != nil {
slog.Error("Failed to send message", "error", err)
Expand Down
13 changes: 11 additions & 2 deletions packages/tui/internal/components/chat/editor.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,17 @@ func (m *editorComponent) Content() string {
}

model := ""
if m.app.Model != nil {
model = muted(m.app.Provider.Name) + base(" "+m.app.Model.Name)
if m.app.MainModel != nil && m.app.MainProvider != nil {
model = muted(m.app.MainProvider.Name) + base(" "+m.app.MainModel.Name)

// show lightweight model if configured
if m.app.LightModel != nil && m.app.LightProvider != nil {
if m.app.LightProvider.Id == m.app.MainProvider.Id {
model = model + muted(" (⚡"+m.app.LightModel.Name+")")
} else {
model = model + muted(" (⚡"+m.app.LightProvider.Name+"/"+m.app.LightModel.Name+")")
}
}
}

space := m.width - 2 - lipgloss.Width(model) - lipgloss.Width(hint)
Expand Down
Loading