Skip to content

Commit be2e238

Browse files
authored
feat: add Anthropic API support with custom version header (#934)
* feat: add Anthropic API support with custom version header * refactor: use switch statement for API type header handling * refactor: add OpenAI & AzureAD types to be exhaustive * Update client.go need explicit fallthrough in empty case statements * constant for APIVersion; addtl tests
1 parent 85f578b commit be2e238

File tree

4 files changed

+89
-6
lines changed

4 files changed

+89
-6
lines changed

client.go

+13-5
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,21 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
182182

183183
func (c *Client) setCommonHeaders(req *http.Request) {
184184
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
185-
// Azure API Key authentication
186-
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
185+
switch c.config.APIType {
186+
case APITypeAzure, APITypeCloudflareAzure:
187+
// Azure API Key authentication
187188
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
188-
} else if c.config.authToken != "" {
189-
// OpenAI or Azure AD authentication
190-
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
189+
case APITypeAnthropic:
190+
// https://docs.anthropic.com/en/api/versioning
191+
req.Header.Set("anthropic-version", c.config.APIVersion)
192+
case APITypeOpenAI, APITypeAzureAD:
193+
fallthrough
194+
default:
195+
if c.config.authToken != "" {
196+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
197+
}
191198
}
199+
192200
if c.config.OrgID != "" {
193201
req.Header.Set("OpenAI-Organization", c.config.OrgID)
194202
}

client_test.go

+15
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,21 @@ func TestClient(t *testing.T) {
3939
}
4040
}
4141

42+
func TestSetCommonHeadersAnthropic(t *testing.T) {
43+
config := DefaultAnthropicConfig("mock-token", "")
44+
client := NewClientWithConfig(config)
45+
req, err := http.NewRequest("GET", "http://example.com", nil)
46+
if err != nil {
47+
t.Fatalf("Failed to create request: %v", err)
48+
}
49+
50+
client.setCommonHeaders(req)
51+
52+
if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion {
53+
t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got)
54+
}
55+
}
56+
4257
func TestDecodeResponse(t *testing.T) {
4358
stringInput := ""
4459

config.go

+21-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ const (
1111

1212
azureAPIPrefix = "openai"
1313
azureDeploymentsPrefix = "deployments"
14+
15+
AnthropicAPIVersion = "2023-06-01"
1416
)
1517

1618
type APIType string
@@ -20,6 +22,7 @@ const (
2022
APITypeAzure APIType = "AZURE"
2123
APITypeAzureAD APIType = "AZURE_AD"
2224
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
25+
APITypeAnthropic APIType = "ANTHROPIC"
2326
)
2427

2528
const AzureAPIKeyHeader = "api-key"
@@ -37,7 +40,7 @@ type ClientConfig struct {
3740
BaseURL string
3841
OrgID string
3942
APIType APIType
40-
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
43+
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic
4144
AssistantVersion string
4245
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
4346
HTTPClient HTTPDoer
@@ -76,6 +79,23 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
7679
}
7780
}
7881

82+
func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig {
83+
if baseURL == "" {
84+
baseURL = "https://api.anthropic.com/v1"
85+
}
86+
return ClientConfig{
87+
authToken: apiKey,
88+
BaseURL: baseURL,
89+
OrgID: "",
90+
APIType: APITypeAnthropic,
91+
APIVersion: AnthropicAPIVersion,
92+
93+
HTTPClient: &http.Client{},
94+
95+
EmptyMessagesLimit: defaultEmptyMessagesLimit,
96+
}
97+
}
98+
7999
func (ClientConfig) String() string {
80100
return "<OpenAI API ClientConfig>"
81101
}

config_test.go

+40
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,43 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
6060
})
6161
}
6262
}
63+
64+
func TestDefaultAnthropicConfig(t *testing.T) {
65+
apiKey := "test-key"
66+
baseURL := "https://api.anthropic.com/v1"
67+
68+
config := openai.DefaultAnthropicConfig(apiKey, baseURL)
69+
70+
if config.APIType != openai.APITypeAnthropic {
71+
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
72+
}
73+
74+
if config.APIVersion != openai.AnthropicAPIVersion {
75+
t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion)
76+
}
77+
78+
if config.BaseURL != baseURL {
79+
t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL)
80+
}
81+
82+
if config.EmptyMessagesLimit != 300 {
83+
t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit)
84+
}
85+
}
86+
87+
func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) {
88+
config := openai.DefaultAnthropicConfig("", "")
89+
90+
if config.APIType != openai.APITypeAnthropic {
91+
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
92+
}
93+
94+
if config.APIVersion != openai.AnthropicAPIVersion {
95+
t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion)
96+
}
97+
98+
expectedBaseURL := "https://api.anthropic.com/v1"
99+
if config.BaseURL != expectedBaseURL {
100+
t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL)
101+
}
102+
}

0 commit comments

Comments
 (0)