diff --git a/cmd/server/main.go b/cmd/server/main.go index 3d9ee6cf99..7353c7d90d 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -24,7 +24,6 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" - "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/store" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator" "github.com/router-for-me/CLIProxyAPI/v6/internal/tui" @@ -495,7 +494,6 @@ func main() { if standalone { // Standalone mode: start an embedded local server and connect TUI client to it. managementasset.StartAutoUpdater(context.Background(), configFilePath) - registry.StartModelsUpdater(context.Background()) hook := tui.NewLogHook(2000) hook.SetFormatter(&logging.LogFormatter{}) log.AddHook(hook) @@ -568,7 +566,6 @@ func main() { } else { // Start the main proxy service managementasset.StartAutoUpdater(context.Background(), configFilePath) - registry.StartModelsUpdater(context.Background()) cmd.StartService(cfg, configFilePath, password) } } diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 2e471ae8ca..264f3a3375 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -1505,7 +1505,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { planType := "" hashAccountID := "" if claims != nil { - planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType) + planType = registry.NormalizeCodexPlanType(claims.CodexAuthInfo.ChatgptPlanType) if accountID := claims.GetAccountID(); accountID != "" { digest := sha256.Sum256([]byte(accountID)) hashAccountID = hex.EncodeToString(digest[:])[:8] @@ -1523,6 +1523,10 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { Metadata: map[string]any{ "email": tokenStorage.Email, "account_id": tokenStorage.AccountID, + "plan_type": planType, + }, + Attributes: map[string]string{ + "plan_type": planType, }, } savedPath, errSave := h.saveTokenRecord(ctx, record) diff --git a/internal/auth/codex/filename.go b/internal/auth/codex/filename.go index fdac5a404c..6b6fa41a32 100644 --- a/internal/auth/codex/filename.go +++ b/internal/auth/codex/filename.go @@ -42,5 +42,11 @@ func normalizePlanTypeForFilename(planType string) string { for i, part := range parts { parts[i] = strings.ToLower(strings.TrimSpace(part)) } - return strings.Join(parts, "-") + normalized := strings.Join(parts, "-") + switch normalized { + case "business", "enterprise", "edu", "education": + return "team" + default: + return normalized + } } diff --git a/internal/auth/codex/filename_test.go b/internal/auth/codex/filename_test.go new file mode 100644 index 0000000000..492bb1a1bb --- /dev/null +++ b/internal/auth/codex/filename_test.go @@ -0,0 +1,11 @@ +package codex + +import "testing" + +func TestCredentialFileNameNormalizesTeamAliases(t *testing.T) { + got := CredentialFileName("tester@example.com", "business", "deadbeef", true) + want := "codex-deadbeef-tester@example.com-team.json" + if got != want { + t.Fatalf("expected %q, got %q", want, got) + } +} diff --git a/internal/registry/codex_catalog.go b/internal/registry/codex_catalog.go new file mode 100644 index 0000000000..a8f9f09e16 --- /dev/null +++ b/internal/registry/codex_catalog.go @@ -0,0 +1,158 @@ +package registry + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strings" +) + +func GetCodexModelsForPlan(planType string) []*ModelInfo { + switch NormalizeCodexPlanType(planType) { + case "pro": + return GetCodexProModels() + case "plus": + return GetCodexPlusModels() + case "team": + return GetCodexTeamModels() + case "free": + fallthrough + default: + return GetCodexFreeModels() + } +} + +func GetCodexModelsUnion() []*ModelInfo { + data := getModels() + sections := [][]*ModelInfo{ + data.CodexFree, + data.CodexTeam, + data.CodexPlus, + data.CodexPro, + } + seen := make(map[string]struct{}) + out := make([]*ModelInfo, 0) + for _, models := range sections { + for _, model := range models { + if model == nil || strings.TrimSpace(model.ID) == "" { + continue + } + if _, ok := seen[model.ID]; ok { + continue + } + seen[model.ID] = struct{}{} + out = append(out, cloneModelInfo(model)) + } + } + return out +} + +func NormalizeCodexPlanType(planType string) string { + switch strings.ToLower(strings.TrimSpace(planType)) { + case "free": + return "free" + case "team", "business", "enterprise", "edu", "education": + return "team" + case "plus": + return "plus" + case "pro": + return "pro" + default: + return "" + } +} + +func ResolveCodexPlanType(attributes map[string]string, metadata map[string]any) string { + if attributes != nil { + for _, key := range []string{"plan_type", "chatgpt_plan_type"} { + if plan := NormalizeCodexPlanType(attributes[key]); plan != "" { + return plan + } + } + } + plan, _ := EnsureCodexPlanTypeMetadata(metadata) + return plan +} + +func EnsureCodexPlanTypeMetadata(metadata map[string]any) (string, bool) { + if metadata == nil { + return "", false + } + for _, key := range []string{"plan_type", "chatgpt_plan_type"} { + if raw, ok := metadata[key].(string); ok { + if plan := NormalizeCodexPlanType(raw); plan != "" { + current, _ := metadata["plan_type"].(string) + if strings.TrimSpace(current) != plan { + metadata["plan_type"] = plan + return plan, true + } + return plan, false + } + } + } + idToken := firstString(metadata, "id_token") + if idToken == "" { + idToken = nestedString(metadata, "token", "id_token") + } + if idToken == "" { + idToken = nestedString(metadata, "tokens", "id_token") + } + if idToken == "" { + return "", false + } + plan, err := extractCodexPlanTypeFromJWT(idToken) + if err != nil { + return "", false + } + if plan == "" { + return "", false + } + metadata["plan_type"] = plan + return plan, true +} + +func firstString(metadata map[string]any, key string) string { + if metadata == nil { + return "" + } + if value, ok := metadata[key].(string); ok { + return strings.TrimSpace(value) + } + return "" +} + +func nestedString(metadata map[string]any, parent, key string) string { + if metadata == nil { + return "" + } + raw, ok := metadata[parent] + if !ok { + return "" + } + child, ok := raw.(map[string]any) + if !ok { + return "" + } + value, _ := child[key].(string) + return strings.TrimSpace(value) +} + +func extractCodexPlanTypeFromJWT(token string) (string, error) { + parts := strings.Split(strings.TrimSpace(token), ".") + if len(parts) != 3 { + return "", fmt.Errorf("invalid jwt format") + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", err + } + var claims struct { + Auth struct { + PlanType string `json:"chatgpt_plan_type"` + } `json:"https://api.openai.com/auth"` + } + if err := json.Unmarshal(payload, &claims); err != nil { + return "", err + } + return NormalizeCodexPlanType(claims.Auth.PlanType), nil +} diff --git a/internal/registry/codex_catalog_test.go b/internal/registry/codex_catalog_test.go new file mode 100644 index 0000000000..fac2ee1032 --- /dev/null +++ b/internal/registry/codex_catalog_test.go @@ -0,0 +1,106 @@ +package registry + +import ( + "encoding/base64" + "encoding/json" + "strings" + "testing" +) + +func TestEnsureCodexPlanTypeMetadataExtractsFromJWT(t *testing.T) { + metadata := map[string]any{ + "type": "codex", + "id_token": testCodexJWT(t, "team"), + } + plan, changed := EnsureCodexPlanTypeMetadata(metadata) + if !changed { + t.Fatal("expected metadata to be updated from id_token") + } + if plan != "team" { + t.Fatalf("expected team plan, got %q", plan) + } + if got, _ := metadata["plan_type"].(string); got != "team" { + t.Fatalf("expected metadata plan_type team, got %#v", metadata["plan_type"]) + } +} + +func TestEnsureCodexPlanTypeMetadataNormalizesAliases(t *testing.T) { + metadata := map[string]any{ + "type": "codex", + "plan_type": "business", + } + plan, changed := EnsureCodexPlanTypeMetadata(metadata) + if !changed { + t.Fatal("expected alias plan_type to be normalized") + } + if plan != "team" { + t.Fatalf("expected normalized team plan, got %q", plan) + } + if got, _ := metadata["plan_type"].(string); got != "team" { + t.Fatalf("expected metadata plan_type team, got %#v", metadata["plan_type"]) + } +} + +func TestGetCodexModelsForPlanUsesSafeFallback(t *testing.T) { + models := GetCodexModelsForPlan("unknown") + if len(models) == 0 { + t.Fatal("expected codex models for unknown plan fallback") + } + for _, model := range models { + if model == nil { + continue + } + if model.ID == "gpt-5.4" || model.ID == "gpt-5.3-codex-spark" { + t.Fatalf("expected unknown plan to avoid higher-tier model %q", model.ID) + } + } +} + +func TestGetStaticModelDefinitionsByChannelCodexReturnsTierUnion(t *testing.T) { + models := GetStaticModelDefinitionsByChannel("codex") + if len(models) == 0 { + t.Fatal("expected codex static model definitions") + } + seen := make(map[string]struct{}, len(models)) + for _, model := range models { + if model == nil { + continue + } + if _, ok := seen[model.ID]; ok { + t.Fatalf("duplicate codex model %q in union", model.ID) + } + seen[model.ID] = struct{}{} + } + for _, required := range []string{"gpt-5.2-codex", "gpt-5.3-codex", "gpt-5.3-codex-spark", "gpt-5.4"} { + if _, ok := seen[required]; !ok { + t.Fatalf("expected codex union to include %q", required) + } + } +} + +func TestLookupStaticModelInfoFindsCodexTierModel(t *testing.T) { + model := LookupStaticModelInfo("gpt-5.3-codex-spark") + if model == nil { + t.Fatal("expected spark model lookup to succeed") + } + if !strings.EqualFold(model.DisplayName, "GPT 5.3 Codex Spark") { + t.Fatalf("unexpected display name: %+v", model) + } +} + +func testCodexJWT(t *testing.T, planType string) string { + t.Helper() + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + payloadRaw, err := json.Marshal(map[string]any{ + "email": "tester@example.com", + "https://api.openai.com/auth": map[string]any{ + "chatgpt_account_id": "acct_123", + "chatgpt_plan_type": planType, + }, + }) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + payload := base64.RawURLEncoding.EncodeToString(payloadRaw) + return header + "." + payload + ".signature" +} diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index b7f5edb17b..c82dc421d1 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -160,7 +160,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo { case "aistudio": return GetAIStudioModels() case "codex": - return GetCodexProModels() + return GetCodexModelsUnion() case "qwen": return GetQwenModels() case "iflow": @@ -209,7 +209,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo { data.Vertex, data.GeminiCLI, data.AIStudio, - data.CodexPro, + GetCodexModelsUnion(), data.Qwen, data.IFlow, data.Kimi, diff --git a/internal/registry/model_updater.go b/internal/registry/model_updater.go index 84c9d6aa63..5f8d1f7c5b 100644 --- a/internal/registry/model_updater.go +++ b/internal/registry/model_updater.go @@ -34,6 +34,7 @@ type modelStore struct { var modelsCatalogStore = &modelStore{} var updaterOnce sync.Once +var updaterDone = make(chan struct{}) func init() { // Load embedded data as fallback on startup. @@ -43,13 +44,24 @@ func init() { } // StartModelsUpdater runs a one-time models refresh on startup. -// It blocks until the startup fetch attempt finishes so service initialization -// can wait for the refreshed catalog before registering auth-backed models. +// It returns immediately and performs the network refresh in the background. +// Every invocation registers an onUpdated callback for the same one-time startup +// attempt, so callers added after initialization still observe completion. +// Callbacks also run when the refresh falls back to the embedded catalog. // Safe to call multiple times; only one refresh will run. -func StartModelsUpdater(ctx context.Context) { +func StartModelsUpdater(ctx context.Context, onUpdated func()) { updaterOnce.Do(func() { - runModelsUpdater(ctx) + go func() { + defer close(updaterDone) + runModelsUpdater(ctx) + }() }) + if onUpdated != nil { + go func() { + <-updaterDone + onUpdated() + }() + } } func runModelsUpdater(ctx context.Context) { diff --git a/internal/registry/model_updater_test.go b/internal/registry/model_updater_test.go new file mode 100644 index 0000000000..b513552bf4 --- /dev/null +++ b/internal/registry/model_updater_test.go @@ -0,0 +1,126 @@ +package registry + +import ( + "context" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" +) + +func TestStartModelsUpdaterReturnsImmediatelyAndInvokesCallback(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + _, _ = w.Write(embeddedModelsJSON) + })) + defer server.Close() + + oldURLs := modelsURLs + oldOnce := updaterOnce + oldDone := updaterDone + modelsURLs = []string{server.URL} + updaterOnce = sync.Once{} + updaterDone = make(chan struct{}) + t.Cleanup(func() { + modelsURLs = oldURLs + updaterOnce = oldOnce + updaterDone = oldDone + }) + + updated := make(chan struct{}, 1) + start := time.Now() + StartModelsUpdater(context.Background(), func() { + select { + case updated <- struct{}{}: + default: + } + }) + if elapsed := time.Since(start); elapsed > 100*time.Millisecond { + t.Fatalf("expected StartModelsUpdater to return quickly, took %s", elapsed) + } + + select { + case <-updated: + case <-time.After(2 * time.Second): + t.Fatal("expected startup models refresh callback to fire") + } +} + +func TestStartModelsUpdaterInvokesCallbackOnRefreshFailure(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "nope", http.StatusInternalServerError) + })) + defer server.Close() + + oldURLs := modelsURLs + oldOnce := updaterOnce + oldDone := updaterDone + modelsURLs = []string{server.URL} + updaterOnce = sync.Once{} + updaterDone = make(chan struct{}) + t.Cleanup(func() { + modelsURLs = oldURLs + updaterOnce = oldOnce + updaterDone = oldDone + }) + + updated := make(chan struct{}, 1) + StartModelsUpdater(context.Background(), func() { + select { + case updated <- struct{}{}: + default: + } + }) + + select { + case <-updated: + case <-time.After(2 * time.Second): + t.Fatal("expected callback even when startup refresh falls back to embedded models") + } +} + +func TestStartModelsUpdaterNotifiesCallbacksRegisteredAfterStart(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(150 * time.Millisecond) + _, _ = w.Write(embeddedModelsJSON) + })) + defer server.Close() + + oldURLs := modelsURLs + oldOnce := updaterOnce + oldDone := updaterDone + modelsURLs = []string{server.URL} + updaterOnce = sync.Once{} + updaterDone = make(chan struct{}) + t.Cleanup(func() { + modelsURLs = oldURLs + updaterOnce = oldOnce + updaterDone = oldDone + }) + + first := make(chan struct{}, 1) + second := make(chan struct{}, 1) + + StartModelsUpdater(context.Background(), func() { + select { + case first <- struct{}{}: + default: + } + }) + time.Sleep(25 * time.Millisecond) + StartModelsUpdater(context.Background(), func() { + select { + case second <- struct{}{}: + default: + } + }) + + for name, ch := range map[string]chan struct{}{"first": first, "second": second} { + select { + case <-ch: + case <-time.After(2 * time.Second): + t.Fatalf("expected %s callback to fire", name) + } + } +} diff --git a/internal/watcher/synthesizer/file.go b/internal/watcher/synthesizer/file.go index ab54aeaaa3..8a75dda08d 100644 --- a/internal/watcher/synthesizer/file.go +++ b/internal/watcher/synthesizer/file.go @@ -10,7 +10,7 @@ import ( "strings" "time" - "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" ) @@ -137,6 +137,11 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) [] CreatedAt: now, UpdatedAt: now, } + if provider == "codex" { + if plan, _ := registry.EnsureCodexPlanTypeMetadata(metadata); plan != "" { + a.Attributes["plan_type"] = plan + } + } // Read priority from auth file. if rawPriority, ok := metadata["priority"]; ok { switch v := rawPriority.(type) { @@ -150,16 +155,6 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) [] } } ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth") - // For codex auth files, extract plan_type from the JWT id_token. - if provider == "codex" { - if idTokenRaw, ok := metadata["id_token"].(string); ok && strings.TrimSpace(idTokenRaw) != "" { - if claims, errParse := codex.ParseJWTToken(idTokenRaw); errParse == nil && claims != nil { - if pt := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); pt != "" { - a.Attributes["plan_type"] = pt - } - } - } - } if provider == "gemini-cli" { if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 { for _, v := range virtuals { diff --git a/sdk/auth/codex_device.go b/sdk/auth/codex_device.go index 10f59fb97b..d30716fcc6 100644 --- a/sdk/auth/codex_device.go +++ b/sdk/auth/codex_device.go @@ -16,6 +16,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" log "github.com/sirupsen/logrus" @@ -262,7 +263,7 @@ func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundl hashAccountID := "" if tokenStorage.IDToken != "" { if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil { - planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType) + planType = registry.NormalizeCodexPlanType(claims.CodexAuthInfo.ChatgptPlanType) accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID) if accountID != "" { digest := sha256.Sum256([]byte(accountID)) @@ -275,6 +276,9 @@ func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundl metadata := map[string]any{ "email": tokenStorage.Email, } + if planType != "" { + metadata["plan_type"] = planType + } fmt.Println("Codex authentication successful") if authBundle.APIKey != "" { diff --git a/sdk/auth/codex_device_test.go b/sdk/auth/codex_device_test.go new file mode 100644 index 0000000000..23963d46c0 --- /dev/null +++ b/sdk/auth/codex_device_test.go @@ -0,0 +1,62 @@ +package auth + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "testing" + + internalcodex "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" +) + +func TestBuildAuthRecordNormalizesTeamAliases(t *testing.T) { + authenticator := NewCodexAuthenticator() + authSvc := &internalcodex.CodexAuth{} + accountID := "acct_business_123" + email := "tester@example.com" + bundle := &internalcodex.CodexAuthBundle{ + TokenData: internalcodex.CodexTokenData{ + IDToken: testCodexDeviceJWT(t, email, accountID, "business"), + AccessToken: "access", + RefreshToken: "refresh", + AccountID: accountID, + Email: email, + }, + LastRefresh: "2026-03-10T00:00:00Z", + } + + record, err := authenticator.buildAuthRecord(authSvc, bundle) + if err != nil { + t.Fatalf("build auth record: %v", err) + } + + digest := sha256.Sum256([]byte(accountID)) + expectedName := "codex-" + hex.EncodeToString(digest[:])[:8] + "-" + email + "-team.json" + if record.FileName != expectedName { + t.Fatalf("expected file name %q, got %q", expectedName, record.FileName) + } + if got, _ := record.Metadata["plan_type"].(string); got != "team" { + t.Fatalf("expected metadata plan_type team, got %#v", record.Metadata["plan_type"]) + } + if got := record.Attributes["plan_type"]; got != "team" { + t.Fatalf("expected attribute plan_type team, got %q", got) + } +} + +func testCodexDeviceJWT(t *testing.T, email, accountID, planType string) string { + t.Helper() + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + payloadRaw, err := json.Marshal(map[string]any{ + "email": email, + "https://api.openai.com/auth": map[string]any{ + "chatgpt_account_id": accountID, + "chatgpt_plan_type": planType, + }, + }) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + payload := base64.RawURLEncoding.EncodeToString(payloadRaw) + return header + "." + payload + ".signature" +} diff --git a/sdk/auth/filestore.go b/sdk/auth/filestore.go index 987d305e88..37ea356295 100644 --- a/sdk/auth/filestore.go +++ b/sdk/auth/filestore.go @@ -15,7 +15,9 @@ import ( "sync" "time" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" ) // FileTokenStore persists token records and auth metadata using the filesystem as backing storage. @@ -197,6 +199,17 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, if provider == "" { provider = "unknown" } + if provider == "codex" { + if _, changed := registry.EnsureCodexPlanTypeMetadata(metadata); changed { + if raw, errMarshal := json.Marshal(metadata); errMarshal == nil { + if errPersist := persistNormalizedMetadata(path, raw); errPersist != nil { + log.Warnf("auth filestore: failed to persist normalized codex plan_type for %s: %v", path, errPersist) + } + } else { + log.Warnf("auth filestore: failed to marshal normalized codex metadata for %s: %v", path, errMarshal) + } + } + } if provider == "antigravity" || provider == "gemini" { projectID := "" if pid, ok := metadata["project_id"].(string); ok { @@ -254,9 +267,26 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, if email, ok := metadata["email"].(string); ok && email != "" { auth.Attributes["email"] = email } + if provider == "codex" { + if plan := registry.ResolveCodexPlanType(auth.Attributes, metadata); plan != "" { + auth.Attributes["plan_type"] = plan + } + } return auth, nil } +func persistNormalizedMetadata(path string, raw []byte) error { + file, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600) + if err != nil { + return err + } + if _, errWrite := file.Write(raw); errWrite != nil { + _ = file.Close() + return errWrite + } + return file.Close() +} + func (s *FileTokenStore) idFor(path, baseDir string) string { id := path if baseDir != "" { diff --git a/sdk/auth/filestore_codex_plan_test.go b/sdk/auth/filestore_codex_plan_test.go new file mode 100644 index 0000000000..0989aca139 --- /dev/null +++ b/sdk/auth/filestore_codex_plan_test.go @@ -0,0 +1,106 @@ +package auth + +import ( + "encoding/base64" + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestReadAuthFile_CodexRecoversPlanTypeFromIDToken(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "codex-user.json") + raw := map[string]any{ + "type": "codex", + "email": "tester@example.com", + "id_token": testCodexJWT(t, "plus"), + } + payload, err := json.Marshal(raw) + if err != nil { + t.Fatalf("marshal auth file: %v", err) + } + if err := os.WriteFile(path, payload, 0o600); err != nil { + t.Fatalf("write auth file: %v", err) + } + + store := NewFileTokenStore() + auth, err := store.readAuthFile(path, dir) + if err != nil { + t.Fatalf("read auth file: %v", err) + } + if auth == nil { + t.Fatal("expected auth") + } + if got := auth.Attributes["plan_type"]; got != "plus" { + t.Fatalf("expected plan_type plus, got %q", got) + } + updated, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read updated auth file: %v", err) + } + var metadata map[string]any + if err := json.Unmarshal(updated, &metadata); err != nil { + t.Fatalf("unmarshal updated auth file: %v", err) + } + if got, _ := metadata["plan_type"].(string); got != "plus" { + t.Fatalf("expected persisted plan_type plus, got %#v", metadata["plan_type"]) + } +} + +func TestReadAuthFile_CodexNormalizesAliasPlanTypeFromIDToken(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "codex-user.json") + raw := map[string]any{ + "type": "codex", + "email": "tester@example.com", + "id_token": testCodexJWT(t, "business"), + } + payload, err := json.Marshal(raw) + if err != nil { + t.Fatalf("marshal auth file: %v", err) + } + if err := os.WriteFile(path, payload, 0o600); err != nil { + t.Fatalf("write auth file: %v", err) + } + + store := NewFileTokenStore() + auth, err := store.readAuthFile(path, dir) + if err != nil { + t.Fatalf("read auth file: %v", err) + } + if auth == nil { + t.Fatal("expected auth") + } + if got := auth.Attributes["plan_type"]; got != "team" { + t.Fatalf("expected normalized plan_type team, got %q", got) + } + updated, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read updated auth file: %v", err) + } + var metadata map[string]any + if err := json.Unmarshal(updated, &metadata); err != nil { + t.Fatalf("unmarshal updated auth file: %v", err) + } + if got, _ := metadata["plan_type"].(string); got != "team" { + t.Fatalf("expected persisted normalized plan_type team, got %#v", metadata["plan_type"]) + } +} + +func testCodexJWT(t *testing.T, planType string) string { + t.Helper() + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + payloadRaw, err := json.Marshal(map[string]any{ + "email": "tester@example.com", + "https://api.openai.com/auth": map[string]any{ + "chatgpt_account_id": "acct_123", + "chatgpt_plan_type": planType, + }, + }) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + payload := base64.RawURLEncoding.EncodeToString(payloadRaw) + return header + "." + payload + ".signature" +} diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 596db3dd8b..1c44868fb0 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -454,6 +454,19 @@ func (s *Service) rebindExecutors() { } } +func (s *Service) rebindRegisteredModels() { + if s == nil || s.coreManager == nil { + return + } + for _, auth := range s.coreManager.List() { + if auth == nil { + continue + } + s.registerModelsForAuth(auth) + s.coreManager.RefreshSchedulerEntry(auth.ID) + } +} + // Run starts the service and blocks until the context is cancelled or the server stops. // It initializes all components including authentication, file watching, HTTP server, // and starts processing requests. The method blocks until the context is cancelled. @@ -633,6 +646,11 @@ func (s *Service) Run(ctx context.Context) error { return fmt.Errorf("cliproxy: failed to start watcher: %w", err) } log.Info("file watcher started for config and auth directory changes") + if s.cfg != nil { + registry.StartModelsUpdater(ctx, func() { + s.rebindRegisteredModels() + }) + } // Prefer core auth manager auto refresh if available. if s.coreManager != nil { @@ -829,29 +847,27 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { } models = applyExcludedModels(models, excluded) case "codex": - codexPlanType := "" - if a.Attributes != nil { - codexPlanType = strings.TrimSpace(a.Attributes["plan_type"]) - } - switch strings.ToLower(codexPlanType) { - case "pro": - models = registry.GetCodexProModels() - case "plus": - models = registry.GetCodexPlusModels() - case "team": - models = registry.GetCodexTeamModels() - case "free": - models = registry.GetCodexFreeModels() - default: - models = registry.GetCodexProModels() - } if entry := s.resolveConfigCodexKey(a); entry != nil { if len(entry.Models) > 0 { models = buildCodexConfigModels(entry) + } else { + models = registry.GetCodexModelsUnion() } if authKind == "apikey" { excluded = entry.ExcludedModels } + } else { + codexPlanType := registry.ResolveCodexPlanType(a.Attributes, a.Metadata) + if a.Attributes == nil { + a.Attributes = make(map[string]string) + } + if codexPlanType != "" { + a.Attributes["plan_type"] = codexPlanType + models = registry.GetCodexModelsForPlan(codexPlanType) + } else { + delete(a.Attributes, "plan_type") + models = registry.GetCodexModelsUnion() + } } models = applyExcludedModels(models, excluded) case "qwen": diff --git a/sdk/cliproxy/service_codex_models_test.go b/sdk/cliproxy/service_codex_models_test.go new file mode 100644 index 0000000000..d89ac034fa --- /dev/null +++ b/sdk/cliproxy/service_codex_models_test.go @@ -0,0 +1,146 @@ +package cliproxy + +import ( + "testing" + + internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +func TestRegisterModelsForAuth_CodexFreePlanSkipsHigherTierModels(t *testing.T) { + service := &Service{cfg: &config.Config{}} + auth := &coreauth.Auth{ + ID: "auth-codex-free", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"plan_type": "free"}, + } + models := registerCodexModelsForTest(t, service, auth) + assertMissingModel(t, models, "gpt-5.3-codex") + assertMissingModel(t, models, "gpt-5.4") + assertMissingModel(t, models, "gpt-5.3-codex-spark") +} + +func TestRegisterModelsForAuth_CodexTeamPlanIncludes54ButNotSpark(t *testing.T) { + service := &Service{cfg: &config.Config{}} + auth := &coreauth.Auth{ + ID: "auth-codex-team", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"plan_type": "team"}, + } + models := registerCodexModelsForTest(t, service, auth) + assertHasModel(t, models, "gpt-5.3-codex") + assertHasModel(t, models, "gpt-5.4") + assertMissingModel(t, models, "gpt-5.3-codex-spark") +} + +func TestRegisterModelsForAuth_CodexPlusPlanIncludesSpark(t *testing.T) { + service := &Service{cfg: &config.Config{}} + auth := &coreauth.Auth{ + ID: "auth-codex-plus", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"plan_type": "plus"}, + } + models := registerCodexModelsForTest(t, service, auth) + assertHasModel(t, models, "gpt-5.3-codex") + assertHasModel(t, models, "gpt-5.4") + assertHasModel(t, models, "gpt-5.3-codex-spark") +} + +func TestRegisterModelsForAuth_CodexUnknownPlanFallsBackToUnion(t *testing.T) { + service := &Service{cfg: &config.Config{}} + auth := &coreauth.Auth{ + ID: "auth-codex-unknown", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{}, + } + models := registerCodexModelsForTest(t, service, auth) + assertHasModel(t, models, "gpt-5.3-codex") + assertHasModel(t, models, "gpt-5.4") + assertHasModel(t, models, "gpt-5.3-codex-spark") +} + +func TestRegisterModelsForAuth_CodexAPIKeyWithoutModelsUsesUnion(t *testing.T) { + service := &Service{ + cfg: &config.Config{ + CodexKey: []config.CodexKey{{ + APIKey: "sk-test", + }}, + }, + } + auth := &coreauth.Auth{ + ID: "auth-codex-apikey", + Provider: "codex", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "auth_kind": "apikey", + "api_key": "sk-test", + }, + Metadata: map[string]any{}, + } + models := registerCodexModelsForTest(t, service, auth) + assertHasModel(t, models, "gpt-5.3-codex") + assertHasModel(t, models, "gpt-5.4") + assertHasModel(t, models, "gpt-5.3-codex-spark") +} + +func TestRegisterModelsForAuth_CodexAPIKeyExplicitModelsOverrideUnion(t *testing.T) { + service := &Service{ + cfg: &config.Config{ + CodexKey: []config.CodexKey{{ + APIKey: "sk-test-explicit", + Models: []internalconfig.CodexModel{{ + Name: "gpt-5.4", + Alias: "gpt-5.4", + }}, + }}, + }, + } + auth := &coreauth.Auth{ + ID: "auth-codex-apikey-explicit", + Provider: "codex", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "auth_kind": "apikey", + "api_key": "sk-test-explicit", + }, + Metadata: map[string]any{}, + } + models := registerCodexModelsForTest(t, service, auth) + assertHasModel(t, models, "gpt-5.4") + assertMissingModel(t, models, "gpt-5.3-codex") + assertMissingModel(t, models, "gpt-5.3-codex-spark") +} + +func registerCodexModelsForTest(t *testing.T, service *Service, auth *coreauth.Auth) []*registry.ModelInfo { + t.Helper() + reg := registry.GetGlobalRegistry() + reg.UnregisterClient(auth.ID) + t.Cleanup(func() { reg.UnregisterClient(auth.ID) }) + service.registerModelsForAuth(auth) + return reg.GetModelsForClient(auth.ID) +} + +func assertHasModel(t *testing.T, models []*registry.ModelInfo, id string) { + t.Helper() + for _, model := range models { + if model != nil && model.ID == id { + return + } + } + t.Fatalf("expected model %q, got %+v", id, models) +} + +func assertMissingModel(t *testing.T, models []*registry.ModelInfo, id string) { + t.Helper() + for _, model := range models { + if model != nil && model.ID == id { + t.Fatalf("did not expect model %q, got %+v", id, models) + } + } +}