Skip to content
3 changes: 3 additions & 0 deletions internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,8 @@ func (s *Server) setupRoutes() {
{
v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers))
v1.POST("/chat/completions", openaiHandlers.ChatCompletions)
// Audio transcription uses a direct multipart passthrough path instead of the JSON translator stack.
v1.POST("/audio/transcriptions", openaiHandlers.AudioTranscriptions)
v1.POST("/completions", openaiHandlers.Completions)
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
Expand All @@ -353,6 +355,7 @@ func (s *Server) setupRoutes() {
"message": "CLI Proxy API Server",
"endpoints": []string{
"POST /v1/chat/completions",
"POST /v1/audio/transcriptions",
"POST /v1/completions",
"GET /v1/models",
},
Expand Down
15 changes: 15 additions & 0 deletions internal/api/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,18 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
}
}
}

func TestRootEndpointIncludesAudioTranscriptions(t *testing.T) {
server := newTestServer(t)

req := httptest.NewRequest(http.MethodGet, "/", nil)
resp := httptest.NewRecorder()
server.engine.ServeHTTP(resp, req)

if resp.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String())
}
if body := resp.Body.String(); !strings.Contains(body, "POST /v1/audio/transcriptions") {
t.Fatalf("response body missing audio transcription endpoint: %s", body)
}
}
36 changes: 31 additions & 5 deletions internal/registry/model_definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type staticModelsJSON struct {
Vertex []*ModelInfo `json:"vertex"`
GeminiCLI []*ModelInfo `json:"gemini-cli"`
AIStudio []*ModelInfo `json:"aistudio"`
CodexShared []*ModelInfo `json:"codex-shared,omitempty"`
CodexFree []*ModelInfo `json:"codex-free"`
CodexTeam []*ModelInfo `json:"codex-team"`
CodexPlus []*ModelInfo `json:"codex-plus"`
Expand Down Expand Up @@ -50,22 +51,22 @@ func GetAIStudioModels() []*ModelInfo {

// GetCodexFreeModels returns model definitions for the Codex free plan tier.
func GetCodexFreeModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexFree)
return cloneModelInfos(mergeStaticModelSections(getModels().CodexShared, getModels().CodexFree))
}

// GetCodexTeamModels returns model definitions for the Codex team plan tier.
func GetCodexTeamModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexTeam)
return cloneModelInfos(mergeStaticModelSections(getModels().CodexShared, getModels().CodexTeam))
}

// GetCodexPlusModels returns model definitions for the Codex plus plan tier.
func GetCodexPlusModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexPlus)
return cloneModelInfos(mergeStaticModelSections(getModels().CodexShared, getModels().CodexPlus))
}

// GetCodexProModels returns model definitions for the Codex pro plan tier.
func GetCodexProModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexPro)
return cloneModelInfos(mergeStaticModelSections(getModels().CodexShared, getModels().CodexPro))
}

// GetQwenModels returns the standard Qwen model definitions.
Expand Down Expand Up @@ -100,6 +101,31 @@ func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
return out
}

func mergeStaticModelSections(sections ...[]*ModelInfo) []*ModelInfo {
var out []*ModelInfo
seen := make(map[string]struct{})
for _, section := range sections {
for _, model := range section {
if model == nil {
continue
}
modelID := strings.TrimSpace(model.ID)
if modelID == "" {
continue
}
if _, exists := seen[modelID]; exists {
continue
}
seen[modelID] = struct{}{}
out = append(out, model)
}
}
if len(out) == 0 {
return nil
}
return out
}

// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
// It returns nil when the channel is unknown.
//
Expand Down Expand Up @@ -156,7 +182,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
data.Vertex,
data.GeminiCLI,
data.AIStudio,
data.CodexPro,
mergeStaticModelSections(data.CodexShared, data.CodexPro),
data.Qwen,
data.IFlow,
data.Kimi,
Expand Down
131 changes: 131 additions & 0 deletions internal/registry/model_definitions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package registry

import (
"strings"
"testing"
)

func TestCodexTierModelsIncludeSharedTranscriptionModels(t *testing.T) {
testCases := []struct {
name string
models []*ModelInfo
}{
{name: "free", models: GetCodexFreeModels()},
{name: "team", models: GetCodexTeamModels()},
{name: "plus", models: GetCodexPlusModels()},
{name: "pro", models: GetCodexProModels()},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if !modelListContainsID(tc.models, "gpt-4o-mini-transcribe") {
t.Fatalf("%s tier missing gpt-4o-mini-transcribe", tc.name)
}
if !modelListContainsID(tc.models, "whisper-1") {
t.Fatalf("%s tier missing whisper-1", tc.name)
}
})
}
}

func TestLookupStaticModelInfoFindsSharedCodexModels(t *testing.T) {
for _, modelID := range []string{"gpt-4o-mini-transcribe", "whisper-1"} {
info := LookupStaticModelInfo(modelID)
if info == nil {
t.Fatalf("LookupStaticModelInfo(%q) = nil", modelID)
}
if info.ID != modelID {
t.Fatalf("LookupStaticModelInfo(%q).ID = %q, want %q", modelID, info.ID, modelID)
}
}
}

func TestValidateModelsCatalogRejectsCodexSharedOverlap(t *testing.T) {
data := newStaticModelsCatalogFixture()
data.CodexShared = []*ModelInfo{testStaticModel("shared-transcribe")}
data.CodexFree = []*ModelInfo{testStaticModel("shared-transcribe")}

err := validateModelsCatalog(data)
if err == nil || !strings.Contains(err.Error(), `codex-shared overlaps codex-free on model id "shared-transcribe"`) {
t.Fatalf("validateModelsCatalog() error = %v, want overlap error", err)
}
}

func TestDetectChangedProvidersTreatsCodexSharedAndDuplicatedSchemasAsEquivalent(t *testing.T) {
oldData := newStaticModelsCatalogFixture()
oldData.CodexFree = []*ModelInfo{
testStaticModel("free-only"),
testStaticModel("gpt-4o-mini-transcribe"),
testStaticModel("whisper-1"),
}
oldData.CodexTeam = []*ModelInfo{
testStaticModel("team-only"),
testStaticModel("gpt-4o-mini-transcribe"),
testStaticModel("whisper-1"),
}
oldData.CodexPlus = []*ModelInfo{
testStaticModel("plus-only"),
testStaticModel("gpt-4o-mini-transcribe"),
testStaticModel("whisper-1"),
}
oldData.CodexPro = []*ModelInfo{
testStaticModel("pro-only"),
testStaticModel("gpt-4o-mini-transcribe"),
testStaticModel("whisper-1"),
}

newData := newStaticModelsCatalogFixture()
newData.CodexShared = []*ModelInfo{
testStaticModel("gpt-4o-mini-transcribe"),
testStaticModel("whisper-1"),
}
newData.CodexFree = []*ModelInfo{testStaticModel("free-only")}
newData.CodexTeam = []*ModelInfo{testStaticModel("team-only")}
newData.CodexPlus = []*ModelInfo{testStaticModel("plus-only")}
newData.CodexPro = []*ModelInfo{testStaticModel("pro-only")}

changed := detectChangedProviders(oldData, newData)
if stringSliceContains(changed, "codex") {
t.Fatalf("detectChangedProviders() = %v, want codex unchanged", changed)
}
}

func newStaticModelsCatalogFixture() *staticModelsJSON {
return &staticModelsJSON{
Claude: []*ModelInfo{testStaticModel("claude-base")},
Gemini: []*ModelInfo{testStaticModel("gemini-base")},
Vertex: []*ModelInfo{testStaticModel("vertex-base")},
GeminiCLI: []*ModelInfo{testStaticModel("gemini-cli-base")},
AIStudio: []*ModelInfo{testStaticModel("aistudio-base")},
CodexFree: []*ModelInfo{testStaticModel("free-base")},
CodexTeam: []*ModelInfo{testStaticModel("team-base")},
CodexPlus: []*ModelInfo{testStaticModel("plus-base")},
CodexPro: []*ModelInfo{testStaticModel("pro-base")},
Qwen: []*ModelInfo{testStaticModel("qwen-base")},
IFlow: []*ModelInfo{testStaticModel("iflow-base")},
Kimi: []*ModelInfo{testStaticModel("kimi-base")},
Antigravity: []*ModelInfo{testStaticModel("antigravity-base")},
}
}

func testStaticModel(id string) *ModelInfo {
return &ModelInfo{ID: id, Object: "model"}
}

func modelListContainsID(models []*ModelInfo, modelID string) bool {
for _, model := range models {
if model != nil && model.ID == modelID {
return true
}
}
return false
}

func stringSliceContains(values []string, target string) bool {
for _, value := range values {
if value == target {
return true
}
}
return false
}
78 changes: 73 additions & 5 deletions internal/registry/model_updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"sort"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -209,10 +210,10 @@ func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
{"vertex", oldData.Vertex, newData.Vertex},
{"gemini-cli", oldData.GeminiCLI, newData.GeminiCLI},
{"aistudio", oldData.AIStudio, newData.AIStudio},
{"codex", oldData.CodexFree, newData.CodexFree},
{"codex", oldData.CodexTeam, newData.CodexTeam},
{"codex", oldData.CodexPlus, newData.CodexPlus},
{"codex", oldData.CodexPro, newData.CodexPro},
{"codex", mergeStaticModelSections(oldData.CodexShared, oldData.CodexFree), mergeStaticModelSections(newData.CodexShared, newData.CodexFree)},
{"codex", mergeStaticModelSections(oldData.CodexShared, oldData.CodexTeam), mergeStaticModelSections(newData.CodexShared, newData.CodexTeam)},
{"codex", mergeStaticModelSections(oldData.CodexShared, oldData.CodexPlus), mergeStaticModelSections(newData.CodexShared, newData.CodexPlus)},
{"codex", mergeStaticModelSections(oldData.CodexShared, oldData.CodexPro), mergeStaticModelSections(newData.CodexShared, newData.CodexPro)},
{"qwen", oldData.Qwen, newData.Qwen},
{"iflow", oldData.IFlow, newData.IFlow},
{"kimi", oldData.Kimi, newData.Kimi},
Expand All @@ -225,7 +226,11 @@ func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
if seen[s.provider] {
continue
}
if modelSectionChanged(s.oldList, s.newList) {
sectionChanged := modelSectionChanged(s.oldList, s.newList)
if s.provider == "codex" {
sectionChanged = modelSectionChangedIgnoringOrder(s.oldList, s.newList)
}
if sectionChanged {
changed = append(changed, s.provider)
seen[s.provider] = true
}
Expand All @@ -249,6 +254,21 @@ func modelSectionChanged(a, b []*ModelInfo) bool {
return string(aj) != string(bj)
}

func modelSectionChangedIgnoringOrder(a, b []*ModelInfo) bool {
return modelSectionChanged(sortedModelSectionByID(a), sortedModelSectionByID(b))
}

func sortedModelSectionByID(models []*ModelInfo) []*ModelInfo {
if len(models) == 0 {
return nil
}
cloned := cloneModelInfos(models)
sort.SliceStable(cloned, func(i, j int) bool {
return strings.TrimSpace(cloned[i].ID) < strings.TrimSpace(cloned[j].ID)
})
return cloned
}

func notifyModelRefresh(changedProviders []string) {
if len(changedProviders) == 0 {
return
Expand Down Expand Up @@ -346,6 +366,24 @@ func validateModelsCatalog(data *staticModelsJSON) error {
return err
}
}
if len(data.CodexShared) > 0 {
if err := validateModelSection("codex-shared", data.CodexShared); err != nil {
return err
}
}
for _, section := range []struct {
name string
models []*ModelInfo
}{
{name: "codex-free", models: data.CodexFree},
{name: "codex-team", models: data.CodexTeam},
{name: "codex-plus", models: data.CodexPlus},
{name: "codex-pro", models: data.CodexPro},
} {
if err := validateDisjointModelSections("codex-shared", data.CodexShared, section.name, section.models); err != nil {
return err
}
}
return nil
}

Expand All @@ -370,3 +408,33 @@ func validateModelSection(section string, models []*ModelInfo) error {
}
return nil
}

func validateDisjointModelSections(leftName string, leftModels []*ModelInfo, rightName string, rightModels []*ModelInfo) error {
if len(leftModels) == 0 || len(rightModels) == 0 {
return nil
}
seen := make(map[string]struct{}, len(leftModels))
for _, model := range leftModels {
if model == nil {
continue
}
modelID := strings.TrimSpace(model.ID)
if modelID == "" {
continue
}
seen[modelID] = struct{}{}
}
for _, model := range rightModels {
if model == nil {
continue
}
modelID := strings.TrimSpace(model.ID)
if modelID == "" {
continue
}
if _, exists := seen[modelID]; exists {
return fmt.Errorf("%s overlaps %s on model id %q", leftName, rightName, modelID)
}
}
return nil
}
Loading
Loading