Commit b77d01e

Support get http header and x-ratelimit-* headers (#507)
* feat: add headers to http response
* feat: support rate limit headers
* fix: go lint
* fix: test coverage
* refactor streamReader
* refactor streamReader
* refactor: NewRateLimitHeaders to newRateLimitHeaders
* refactor: RateLimitHeaders Resets filed
* refactor: move RateLimitHeaders struct
1 parent 8e165dc commit b77d01e
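
In practice, the new surface lets callers read both the raw HTTP response headers and the parsed x-ratelimit-* values from a response (and, via the embedded httpHeader, from a stream). A minimal usage sketch; the auth token and the "x-request-id" header name are illustrative placeholders, not part of this commit:

    package main

    import (
        "context"
        "fmt"

        openai "github.com/sashabaranov/go-openai"
    )

    func main() {
        client := openai.NewClient("your-token") // placeholder token

        resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
            Model: openai.GPT3Dot5Turbo,
            Messages: []openai.ChatCompletionMessage{
                {Role: openai.ChatMessageRoleUser, Content: "Hello!"},
            },
        })
        if err != nil {
            panic(err)
        }

        // Raw header access, added by this commit.
        fmt.Println(resp.Header().Get("x-request-id")) // illustrative header name

        // Parsed x-ratelimit-* headers.
        rl := resp.GetRateLimitHeaders()
        fmt.Printf("requests left: %d, tokens left: %d, tokens reset in: %s\n",
            rl.RemainingRequests, rl.RemainingTokens, rl.ResetTokens.String())
    }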

File tree

chat_stream_test.go
chat_test.go
client.go
ratelimit.go
stream_reader.go

5 files changed (+191, -5 lines)


chat_stream_test.go (+86, -3)
@@ -1,15 +1,17 @@
 package openai_test
 
 import (
-    . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test/checks"
-
     "context"
     "encoding/json"
     "errors"
+    "fmt"
     "io"
     "net/http"
+    "strconv"
     "testing"
+
+    . "github.com/sashabaranov/go-openai"
+    "github.com/sashabaranov/go-openai/internal/test/checks"
 )
 
 func TestChatCompletionsStreamWrongModel(t *testing.T) {
@@ -178,6 +180,87 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
     t.Logf("%+v\n", apiErr)
 }
 
+func TestCreateChatCompletionStreamWithHeaders(t *testing.T) {
+    client, server, teardown := setupOpenAITestServer()
+    defer teardown()
+    server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+        w.Header().Set("Content-Type", "text/event-stream")
+        w.Header().Set(xCustomHeader, xCustomHeaderValue)
+
+        // Send test responses
+        //nolint:lll
+        dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`)
+        dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)
+
+        _, err := w.Write(dataBytes)
+        checks.NoError(t, err, "Write error")
+    })
+
+    stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
+        MaxTokens: 5,
+        Model:     GPT3Dot5Turbo,
+        Messages: []ChatCompletionMessage{
+            {
+                Role:    ChatMessageRoleUser,
+                Content: "Hello!",
+            },
+        },
+        Stream: true,
+    })
+    checks.NoError(t, err, "CreateCompletionStream returned error")
+    defer stream.Close()
+
+    value := stream.Header().Get(xCustomHeader)
+    if value != xCustomHeaderValue {
+        t.Errorf("expected %s to be %s", xCustomHeaderValue, value)
+    }
+}
+
+func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) {
+    client, server, teardown := setupOpenAITestServer()
+    defer teardown()
+    server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+        w.Header().Set("Content-Type", "text/event-stream")
+        for k, v := range rateLimitHeaders {
+            switch val := v.(type) {
+            case int:
+                w.Header().Set(k, strconv.Itoa(val))
+            default:
+                w.Header().Set(k, fmt.Sprintf("%s", v))
+            }
+        }
+
+        // Send test responses
+        //nolint:lll
+        dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`)
+        dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)
+
+        _, err := w.Write(dataBytes)
+        checks.NoError(t, err, "Write error")
+    })
+
+    stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
+        MaxTokens: 5,
+        Model:     GPT3Dot5Turbo,
+        Messages: []ChatCompletionMessage{
+            {
+                Role:    ChatMessageRoleUser,
+                Content: "Hello!",
+            },
+        },
+        Stream: true,
+    })
+    checks.NoError(t, err, "CreateCompletionStream returned error")
+    defer stream.Close()
+
+    headers := stream.GetRateLimitHeaders()
+    bs1, _ := json.Marshal(headers)
+    bs2, _ := json.Marshal(rateLimitHeaders)
+    if string(bs1) != string(bs2) {
+        t.Errorf("expected rate limit header %s to be %s", bs2, bs1)
+    }
+}
+
 func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
     client, server, teardown := setupOpenAITestServer()
     defer teardown()

chat_test.go (+53)
@@ -21,6 +21,17 @@ const (
     xCustomHeaderValue = "test"
 )
 
+var (
+    rateLimitHeaders = map[string]any{
+        "x-ratelimit-limit-requests":     60,
+        "x-ratelimit-limit-tokens":       150000,
+        "x-ratelimit-remaining-requests": 59,
+        "x-ratelimit-remaining-tokens":   149984,
+        "x-ratelimit-reset-requests":     "1s",
+        "x-ratelimit-reset-tokens":       "6m0s",
+    }
+)
+
 func TestChatCompletionsWrongModel(t *testing.T) {
     config := DefaultConfig("whatever")
     config.BaseURL = "http://localhost/v1"
@@ -97,6 +108,40 @@ func TestChatCompletionsWithHeaders(t *testing.T) {
     }
 }
 
+// TestChatCompletionsWithRateLimitHeaders Tests the completions endpoint of the API using the mocked server.
+func TestChatCompletionsWithRateLimitHeaders(t *testing.T) {
+    client, server, teardown := setupOpenAITestServer()
+    defer teardown()
+    server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
+    resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
+        MaxTokens: 5,
+        Model:     GPT3Dot5Turbo,
+        Messages: []ChatCompletionMessage{
+            {
+                Role:    ChatMessageRoleUser,
+                Content: "Hello!",
+            },
+        },
+    })
+    checks.NoError(t, err, "CreateChatCompletion error")
+
+    headers := resp.GetRateLimitHeaders()
+    resetRequests := headers.ResetRequests.String()
+    if resetRequests != rateLimitHeaders["x-ratelimit-reset-requests"] {
+        t.Errorf("expected resetRequests %s to be %s", resetRequests, rateLimitHeaders["x-ratelimit-reset-requests"])
+    }
+    resetRequestsTime := headers.ResetRequests.Time()
+    if resetRequestsTime.Before(time.Now()) {
+        t.Errorf("unexpected reset requetsts: %v", resetRequestsTime)
+    }
+
+    bs1, _ := json.Marshal(headers)
+    bs2, _ := json.Marshal(rateLimitHeaders)
+    if string(bs1) != string(bs2) {
+        t.Errorf("expected rate limit header %s to be %s", bs2, bs1)
+    }
+}
+
 // TestChatCompletionsFunctions tests including a function call.
 func TestChatCompletionsFunctions(t *testing.T) {
     client, server, teardown := setupOpenAITestServer()
@@ -311,6 +356,14 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
     }
     resBytes, _ = json.Marshal(res)
     w.Header().Set(xCustomHeader, xCustomHeaderValue)
+    for k, v := range rateLimitHeaders {
+        switch val := v.(type) {
+        case int:
+            w.Header().Set(k, strconv.Itoa(val))
+        default:
+            w.Header().Set(k, fmt.Sprintf("%s", v))
+        }
+    }
     fmt.Fprintln(w, string(resBytes))
 }

client.go (+7, -2)
@@ -30,8 +30,12 @@ func (h *httpHeader) SetHeader(header http.Header) {
     *h = httpHeader(header)
 }
 
-func (h httpHeader) Header() http.Header {
-    return http.Header(h)
+func (h *httpHeader) Header() http.Header {
+    return http.Header(*h)
+}
+
+func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders {
+    return newRateLimitHeaders(h.Header())
 }
 
 // NewClient creates new OpenAI API client.
@@ -156,6 +160,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
         response:       resp,
         errAccumulator: utils.NewErrorAccumulator(),
         unmarshaler:    &utils.JSONUnmarshaler{},
+        httpHeader:     httpHeader(resp.Header),
     }, nil
 }

ratelimit.go (+43, new file)
@@ -0,0 +1,43 @@
+package openai
+
+import (
+    "net/http"
+    "strconv"
+    "time"
+)
+
+// RateLimitHeaders struct represents Openai rate limits headers.
+type RateLimitHeaders struct {
+    LimitRequests     int       `json:"x-ratelimit-limit-requests"`
+    LimitTokens       int       `json:"x-ratelimit-limit-tokens"`
+    RemainingRequests int       `json:"x-ratelimit-remaining-requests"`
+    RemainingTokens   int       `json:"x-ratelimit-remaining-tokens"`
+    ResetRequests     ResetTime `json:"x-ratelimit-reset-requests"`
+    ResetTokens       ResetTime `json:"x-ratelimit-reset-tokens"`
+}
+
+type ResetTime string
+
+func (r ResetTime) String() string {
+    return string(r)
+}
+
+func (r ResetTime) Time() time.Time {
+    d, _ := time.ParseDuration(string(r))
+    return time.Now().Add(d)
+}
+
+func newRateLimitHeaders(h http.Header) RateLimitHeaders {
+    limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests"))
+    limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens"))
+    remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests"))
+    remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens"))
+    return RateLimitHeaders{
+        LimitRequests:     limitReq,
+        LimitTokens:       limitTokens,
+        RemainingRequests: remainingReq,
+        RemainingTokens:   remainingTokens,
+        ResetRequests:     ResetTime(h.Get("x-ratelimit-reset-requests")),
+        ResetTokens:       ResetTime(h.Get("x-ratelimit-reset-tokens")),
+    }
+}
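
One behavioral detail worth calling out: ResetTime.Time() treats the header value as a Go duration relative to now, and the parse error is discarded, so an unparseable value degrades to roughly time.Now(). A small sketch using values shaped like the test fixtures above:

    package main

    import (
        "fmt"

        openai "github.com/sashabaranov/go-openai"
    )

    func main() {
        reset := openai.ResetTime("6m0s") // same form as the test fixtures
        fmt.Println(reset.String())       // "6m0s"
        fmt.Println(reset.Time())         // roughly time.Now() plus six minutes

        // time.ParseDuration rejects this form; Time() drops the error,
        // so the zero duration yields roughly time.Now().
        fmt.Println(openai.ResetTime("tomorrow").Time())
    }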

stream_reader.go (+2)
@@ -27,6 +27,8 @@ type streamReader[T streamable] struct {
     response       *http.Response
     errAccumulator utils.ErrorAccumulator
     unmarshaler    utils.Unmarshaler
+
+    httpHeader
 }
 
 func (stream *streamReader[T]) Recv() (response T, err error) {
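
Embedding httpHeader here is what makes the promoted Header() and GetRateLimitHeaders() methods available on streams, with sendRequestStream (client.go above) filling the field from resp.Header. A sketch of the promoted methods in use; printStreamHeaders is a hypothetical helper, and it assumes a configured client, context, and request plus the usual context/fmt/openai imports:

    // printStreamHeaders is a hypothetical helper, not part of this commit.
    func printStreamHeaders(ctx context.Context, client *openai.Client, req openai.ChatCompletionRequest) error {
        stream, err := client.CreateChatCompletionStream(ctx, req)
        if err != nil {
            return err
        }
        defer stream.Close()

        // Both methods are promoted from the embedded httpHeader.
        fmt.Println(stream.Header().Get("Content-Type"))
        fmt.Println(stream.GetRateLimitHeaders().RemainingRequests)
        return nil
    }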
