From 3c9c80c3049c7038df406bcfeae7023d668e743a Mon Sep 17 00:00:00 2001 From: Abhirup Date: Tue, 4 Nov 2025 00:39:50 +0530 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A5=20feat:=20Add=20StreamResponseBody?= =?UTF-8?q?=20support=20for=20the=20Client?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 14 + client/client_test.go | 62 +++ client/core.go | 16 +- client/core_test.go | 7 + client/response.go | 49 ++- client/response_test.go | 833 +++++++++++++++++++++++++++++++++++++++ client/transport.go | 41 ++ client/transport_test.go | 200 ++++++++++ ctx_test.go | 4 +- docs/client/rest.md | 16 + 10 files changed, 1226 insertions(+), 16 deletions(-) diff --git a/client/client.go b/client/client.go index 4a630ae5600..f62480ea83e 100644 --- a/client/client.go +++ b/client/client.go @@ -526,6 +526,20 @@ func (c *Client) DisableDebug() *Client { return c } +// StreamResponseBody returns the current StreamResponseBody setting. +func (c *Client) StreamResponseBody() bool { + return c.transport.StreamResponseBody() +} + +// SetStreamResponseBody enables or disables response body streaming. +// When enabled, the response body can be read as a stream using BodyStream() +// instead of being fully loaded into memory. This is useful for large responses +// or server-sent events. +func (c *Client) SetStreamResponseBody(enable bool) *Client { + c.transport.SetStreamResponseBody(enable) + return c +} + // SetCookieJar sets the cookie jar for the client. func (c *Client) SetCookieJar(cookieJar *CookieJar) *Client { c.cookieJar = cookieJar diff --git a/client/client_test.go b/client/client_test.go index 9fb14496b23..70388944135 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -2343,3 +2343,65 @@ func Benchmark_Client_Request_Send_ContextCancel(b *testing.B) { require.ErrorIs(b, <-errCh, ErrTimeoutOrCancel) } } + +func Test_Client_StreamResponseBody(t *testing.T) { + t.Parallel() + + t.Run("default value", func(t *testing.T) { + t.Parallel() + client := New() + require.False(t, client.StreamResponseBody()) + }) + + t.Run("enable streaming", func(t *testing.T) { + t.Parallel() + client := New() + result := client.SetStreamResponseBody(true) + require.True(t, client.StreamResponseBody()) + require.Equal(t, client, result) + }) + + t.Run("disable streaming", func(t *testing.T) { + t.Parallel() + client := New() + client.SetStreamResponseBody(true) + require.True(t, client.StreamResponseBody()) + client.SetStreamResponseBody(false) + require.False(t, client.StreamResponseBody()) + }) + + t.Run("with standard client", func(t *testing.T) { + t.Parallel() + client := New() + client.SetStreamResponseBody(true) + require.True(t, client.StreamResponseBody()) + }) + + t.Run("with host client", func(t *testing.T) { + t.Parallel() + hostClient := &fasthttp.HostClient{} + client := NewWithHostClient(hostClient) + client.SetStreamResponseBody(true) + require.True(t, client.StreamResponseBody()) + require.True(t, hostClient.StreamResponseBody) + }) + + t.Run("with lb client", func(t *testing.T) { + t.Parallel() + lbClient := &fasthttp.LBClient{ + Clients: []fasthttp.BalancingClient{ + &fasthttp.HostClient{Addr: "example.com:80"}, + }, + } + client := NewWithLBClient(lbClient) + client.SetStreamResponseBody(true) + require.True(t, client.StreamResponseBody()) + }) + + t.Run("getter with standard client without setter", func(t *testing.T) { + t.Parallel() + client := New() + // Test getter directly without calling setter + require.False(t, client.StreamResponseBody()) + }) +} diff --git a/client/core.go b/client/core.go index 13084bd28b6..48ab60c70a6 100644 --- a/client/core.go +++ b/client/core.go @@ -86,9 +86,16 @@ func (c *core) execFunc() (*Response, error) { defer fasthttp.ReleaseRequest(reqv) respv := fasthttp.AcquireResponse() - defer fasthttp.ReleaseResponse(respv) + defer func() { + if respv != nil { + fasthttp.ReleaseResponse(respv) + } + }() c.req.RawRequest.CopyTo(reqv) + if bodyStream := c.req.RawRequest.BodyStream(); bodyStream != nil { + reqv.SetBodyStream(bodyStream, c.req.RawRequest.Header.ContentLength()) + } var err error if cfg != nil { @@ -115,7 +122,12 @@ func (c *core) execFunc() (*Response, error) { resp := AcquireResponse() resp.setClient(c.client) resp.setRequest(c.req) - respv.CopyTo(resp.RawResponse) + originalRaw := resp.RawResponse + resp.RawResponse = respv + respv = nil + if originalRaw != nil { + fasthttp.ReleaseResponse(originalRaw) + } respChan <- resp }() diff --git a/client/core_test.go b/client/core_test.go index 6d5bace1e93..c20d4e7ad1d 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -383,6 +383,13 @@ func (*blockingErrTransport) Client() any { return nil } +func (*blockingErrTransport) StreamResponseBody() bool { + return false +} + +func (*blockingErrTransport) SetStreamResponseBody(_ bool) { +} + func (b *blockingErrTransport) release() { b.releaseOnce.Do(func() { close(b.unblock) }) } diff --git a/client/response.go b/client/response.go index c6a3bcd89bc..2c051a99189 100644 --- a/client/response.go +++ b/client/response.go @@ -11,7 +11,7 @@ import ( "path/filepath" "sync" - "github.com/gofiber/utils/v2" + utils "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -89,22 +89,38 @@ func (r *Response) Body() []byte { return r.RawResponse.Body() } +// BodyStream returns the response body as a stream reader. +// Note: When using BodyStream(), the response body is not copied to memory, +// so calling Body() afterwards may return an empty slice. +func (r *Response) BodyStream() io.Reader { + if stream := r.RawResponse.BodyStream(); stream != nil { + return stream + } + // If streaming is not enabled, return a bytes.Reader from the regular body + return bytes.NewReader(r.RawResponse.Body()) +} + +// IsStreaming returns true if the response body is being streamed. +func (r *Response) IsStreaming() bool { + return r.RawResponse.BodyStream() != nil +} + // String returns the response body as a trimmed string. func (r *Response) String() string { return utils.Trim(string(r.Body()), ' ') } -// JSON unmarshal the response body into the given interface{} using JSON. +// JSON unmarshals the response body into the given interface{} using JSON. func (r *Response) JSON(v any) error { return r.client.jsonUnmarshal(r.Body(), v) } -// CBOR unmarshal the response body into the given interface{} using CBOR. +// CBOR unmarshals the response body into the given interface{} using CBOR. func (r *Response) CBOR(v any) error { return r.client.cborUnmarshal(r.Body(), v) } -// XML unmarshal the response body into the given interface{} using XML. +// XML unmarshals the response body into the given interface{} using XML. func (r *Response) XML(v any) error { return r.client.xmlUnmarshal(r.Body(), v) } @@ -136,21 +152,30 @@ func (r *Response) Save(v any) error { } defer func() { _ = outFile.Close() }() //nolint:errcheck // not needed - if _, err = io.Copy(outFile, bytes.NewReader(r.Body())); err != nil { + if r.IsStreaming() { + _, err = io.Copy(outFile, r.BodyStream()) + } else { + _, err = io.Copy(outFile, bytes.NewReader(r.Body())) + } + + if err != nil { return fmt.Errorf("failed to write response body to file: %w", err) } return nil case io.Writer: - if _, err := io.Copy(p, bytes.NewReader(r.Body())); err != nil { - return fmt.Errorf("failed to write response body to io.Writer: %w", err) + var err error + if r.IsStreaming() { + _, err = io.Copy(p, r.BodyStream()) + } else { + _, err = io.Copy(p, bytes.NewReader(r.Body())) } - defer func() { - if pc, ok := p.(io.WriteCloser); ok { - _ = pc.Close() //nolint:errcheck // not needed - } - }() + + if err != nil { + return fmt.Errorf("failed to write response body to writer: %w", err) + } + return nil default: diff --git a/client/response_test.go b/client/response_test.go index 200a5b9f73b..7557c8ff9f7 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -1,13 +1,18 @@ package client import ( + "bufio" "bytes" "crypto/tls" + "encoding/json" "encoding/xml" + "fmt" "io" "net" "os" + "strings" "testing" + "time" "github.com/gofiber/fiber/v3/internal/tlstest" @@ -537,4 +542,832 @@ func Test_Response_Save(t *testing.T) { err = resp.Save(nil) require.Error(t, err) }) + + t.Run("streaming with io.Writer without closing", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/stream", func(c fiber.Ctx) error { + c.Set("Transfer-Encoding", "chunked") + data := make([]byte, 1024*8) // 8KB + for i := range data { + data[i] = byte('S') + } + return c.SendStreamWriter(func(w *bufio.Writer) { + if _, err := w.Write(data); err != nil { + return + } + if err := w.Flush(); err != nil { + return + } + }) + }) + }) + defer server.stop() + + client := New().SetDial(server.dial()).SetStreamResponseBody(true) + + resp, err := client.Get("http://example.com/stream") + require.NoError(t, err) + defer resp.Close() + + // Custom writer that tracks if it's closed + closableBuffer := &testClosableBuffer{} + + err = resp.Save(closableBuffer) + require.NoError(t, err) + + // Check content + require.Contains(t, closableBuffer.String(), "SSSSSS") + + // Check that the writer was not closed by Save() + require.False(t, closableBuffer.closed, "Save() should not close the writer") + }) + + t.Run("streaming with io.Writer error during copy", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/stream", func(c fiber.Ctx) error { + c.Set("Transfer-Encoding", "chunked") + return c.SendStreamWriter(func(w *bufio.Writer) { + data := []byte("streaming data that will fail to write") + if _, err := w.Write(data); err != nil { + return + } + if err := w.Flush(); err != nil { + return + } + }) + }) + }) + defer server.stop() + + client := New().SetDial(server.dial()).SetStreamResponseBody(true) + + resp, err := client.Get("http://example.com/stream") + require.NoError(t, err) + defer resp.Close() + + // Use a writer that will fail after a few bytes + errorWriter := &testErrorWriter{maxBytes: 5} + + err = resp.Save(errorWriter) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to write response body to writer") + require.Contains(t, err.Error(), "write error after 5 bytes") + }) +} + +// testClosableBuffer is a helper for testing writers that should not be closed. +type testClosableBuffer struct { + bytes.Buffer + closed bool +} + +// Close implements the io.Closer interface. +func (tcb *testClosableBuffer) Close() error { + tcb.closed = true + return nil +} + +// testErrorWriter is a helper for testing write errors during io.CopyBuffer. +type testErrorWriter struct { + maxBytes int + written int +} + +func (tew *testErrorWriter) Write(p []byte) (int, error) { + if tew.written >= tew.maxBytes { + return 0, fmt.Errorf("write error after %d bytes", tew.maxBytes) + } + + remainingBytes := tew.maxBytes - tew.written + if len(p) <= remainingBytes { + tew.written += len(p) + return len(p), nil + } + + // Write only up to maxBytes, then return error + tew.written += remainingBytes + return remainingBytes, fmt.Errorf("write error after %d bytes", tew.maxBytes) +} + +func Test_Response_BodyStream(t *testing.T) { + t.Parallel() + + t.Run("basic streaming", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/stream", func(c fiber.Ctx) error { + return c.SendStream(bytes.NewReader([]byte("streaming data"))) + }) + }) + defer server.stop() + + client := New().SetDial(server.dial()).SetStreamResponseBody(true) + + resp, err := client.Get("http://example.com/stream") + require.NoError(t, err) + defer resp.Close() + bodyStream := resp.BodyStream() + require.NotNil(t, bodyStream) + data, err := io.ReadAll(bodyStream) + require.NoError(t, err) + require.Equal(t, "streaming data", string(data)) + }) + + t.Run("large response streaming", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/large", func(c fiber.Ctx) error { + data := make([]byte, 1024) + for i := range data { + data[i] = byte('A' + i%26) + } + return c.SendStream(bytes.NewReader(data)) + }) + }) + defer server.stop() + + client := New().SetDial(server.dial()).SetStreamResponseBody(true) + resp, err := client.Get("http://example.com/large") + require.NoError(t, err) + defer resp.Close() + bodyStream := resp.BodyStream() + require.NotNil(t, bodyStream) + buffer := make([]byte, 256) + var totalRead []byte + for { + n, err := bodyStream.Read(buffer) + if n > 0 { + totalRead = append(totalRead, buffer[:n]...) + } + if err == io.EOF { + break + } + require.NoError(t, err) + } + require.Len(t, totalRead, 1024) + for i := 0; i < 1024; i++ { + expected := byte('A' + i%26) + require.Equal(t, expected, totalRead[i]) + } + }) + + t.Run("compare with regular body", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/stream", func(c fiber.Ctx) error { + c.Set("Transfer-Encoding", "chunked") + return c.SendStreamWriter(func(w *bufio.Writer) { + data := []byte("streaming data") + if _, err := w.Write(data); err != nil { + return + } + if err := w.Flush(); err != nil { + return + } + }) + }) + }) + defer server.stop() + + client1 := New().SetDial(server.dial()) + resp1, err := client1.Get("http://example.com/stream") + require.NoError(t, err) + defer resp1.Close() + normalBody := resp1.Body() + client2 := New().SetDial(server.dial()).SetStreamResponseBody(true) + resp2, err := client2.Get("http://example.com/stream") + require.NoError(t, err) + defer resp2.Close() + streamedBody, err := io.ReadAll(resp2.BodyStream()) + require.NoError(t, err) + require.Equal(t, normalBody, streamedBody) + }) + + t.Run("chunked streaming with delays", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/chunked", func(c fiber.Ctx) error { + c.Set("Content-Type", "text/plain") + chunks := []string{"chunk1", "chunk2", "chunk3"} + return c.SendStreamWriter(func(w *bufio.Writer) { + for i, chunk := range chunks { + if _, err := w.WriteString(chunk); err != nil { + return + } + if err := w.Flush(); err != nil { + return + } + if i < len(chunks)-1 { + time.Sleep(10 * time.Millisecond) // Shorter delay for faster tests + } + } + }) + }) + }) + defer server.stop() + + client := New().SetDial(server.dial()).SetStreamResponseBody(true) + + resp, err := client.Get("http://example.com/chunked") + require.NoError(t, err) + defer resp.Close() + + bodyStream := resp.BodyStream() + require.NotNil(t, bodyStream) + + var receivedChunks []string + buffer := make([]byte, 10) + + for { + n, err := bodyStream.Read(buffer) + if n > 0 { + chunk := string(buffer[:n]) + receivedChunks = append(receivedChunks, chunk) + } + if err == io.EOF { + break + } + require.NoError(t, err) + } + + fullContent := strings.Join(receivedChunks, "") + require.Equal(t, "chunk1chunk2chunk3", fullContent) + require.GreaterOrEqual(t, len(receivedChunks), 1, "Should receive data chunks") + }) + + t.Run("server sent events with incremental reads", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/sse", func(c fiber.Ctx) error { + c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + + messages := []string{ + "data: event 1\n\n", + "data: event 2\n\n", + "data: event 3\n\n", + "data: event 4\n\n", + } + + return c.SendStreamWriter(func(w *bufio.Writer) { + for i, msg := range messages { + if _, err := w.WriteString(msg); err != nil { + return + } + if err := w.Flush(); err != nil { + return + } + if i < len(messages)-1 { + time.Sleep(5 * time.Millisecond) + } + } + }) + }) + }) + defer server.stop() + + client := New().SetDial(server.dial()).SetStreamResponseBody(true) + + resp, err := client.Get("http://example.com/sse") + require.NoError(t, err) + defer resp.Close() + + bodyStream := resp.BodyStream() + require.NotNil(t, bodyStream) + + data, err := io.ReadAll(bodyStream) + require.NoError(t, err) + + content := string(data) + require.Contains(t, content, "event 1") + require.Contains(t, content, "event 2") + require.Contains(t, content, "event 3") + require.Contains(t, content, "event 4") + require.Contains(t, content, "data: event") + require.Contains(t, content, "\n\n") + }) + + t.Run("progressive json streaming", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/json-stream", func(c fiber.Ctx) error { + c.Set("Content-Type", "application/json") + jsonParts := []string{ + `[`, + `{"id":1,"name":"item1"},`, + `{"id":2,"name":"item2"},`, + `{"id":3,"name":"item3"}`, + `]`, + } + return c.SendStreamWriter(func(w *bufio.Writer) { + for i, part := range jsonParts { + if _, err := w.WriteString(part); err != nil { + return + } + if err := w.Flush(); err != nil { + return + } + + if i < len(jsonParts)-1 { + time.Sleep(2 * time.Millisecond) + } + } + }) + }) + }) + defer server.stop() + + client := New().SetDial(server.dial()).SetStreamResponseBody(true) + + resp, err := client.Get("http://example.com/json-stream") + require.NoError(t, err) + defer resp.Close() + + bodyStream := resp.BodyStream() + require.NotNil(t, bodyStream) + + data, err := io.ReadAll(bodyStream) + require.NoError(t, err) + + fullJSON := string(data) + require.JSONEq(t, `[{"id":1,"name":"item1"},{"id":2,"name":"item2"},{"id":3,"name":"item3"}]`, fullJSON) + var items []map[string]any + err = json.Unmarshal([]byte(fullJSON), &items) + require.NoError(t, err) + require.Len(t, items, 3) + }) + + t.Run("connection interruption handling", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/interrupted", func(c fiber.Ctx) error { + c.Set("Content-Type", "text/plain") + c.Response().ImmediateHeaderFlush = true + + return c.SendStreamWriter(func(w *bufio.Writer) { + if _, err := w.WriteString("initial data"); err != nil { + return + } + if err := w.Flush(); err != nil { + return + } + + time.Sleep(10 * time.Millisecond) + + if _, err := w.WriteString(" more data"); err != nil { + return + } + if err := w.Flush(); err != nil { + return + } + }) + }) + }) + defer server.stop() + + client := New().SetDial(server.dial()).SetStreamResponseBody(true) + + resp, err := client.Get("http://example.com/interrupted") + require.NoError(t, err) + + bodyStream := resp.BodyStream() + require.NotNil(t, bodyStream) + + buffer := make([]byte, 64) + n, err := bodyStream.Read(buffer) + require.NoError(t, err) + require.Contains(t, string(buffer[:n]), "initial") + + // Close the response - this will invalidate the stream + resp.Close() + + // Test that reading after close doesn't crash + // The behavior is undefined - it may return an error or panic + // We use recover to ensure it doesn't crash the test + func() { + defer func() { + if r := recover(); r != nil { + // Panic is acceptable - stream was closed + t.Logf("Reading after close caused panic (expected): %v", r) + } + }() + _, readErr := bodyStream.Read(buffer) + if readErr != nil { + // Error is also acceptable + t.Logf("Reading after close returned error (expected): %v", readErr) + } + }() + }) + + t.Run("large response streaming validation", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/large", func(c fiber.Ctx) error { + c.Set("Content-Type", "text/plain") + c.Response().ImmediateHeaderFlush = true + + const chunkSize = 1024 + const numChunks = 10 + + return c.SendStreamWriter(func(w *bufio.Writer) { + for i := 0; i < numChunks; i++ { + chunk := make([]byte, chunkSize) + for j := 0; j < chunkSize; j++ { + chunk[j] = byte('A' + ((i*chunkSize + j) % 26)) + } + + if _, err := w.Write(chunk); err != nil { + return + } + if err := w.Flush(); err != nil { + return + } + + if i < numChunks-1 { + time.Sleep(time.Millisecond) + } + } + }) + }) + }) + defer server.stop() + + client := New().SetDial(server.dial()).SetStreamResponseBody(true) + + resp, err := client.Get("http://example.com/large") + require.NoError(t, err) + defer resp.Close() + + bodyStream := resp.BodyStream() + require.NotNil(t, bodyStream) + + buffer := make([]byte, 512) + var totalRead []byte + readCount := 0 + + for { + n, err := bodyStream.Read(buffer) + if n > 0 { + totalRead = append(totalRead, buffer[:n]...) + readCount++ + } + if err == io.EOF { + break + } + require.NoError(t, err) + } + + expectedSize := 1024 * 10 + require.Len(t, totalRead, expectedSize) + require.Greater(t, readCount, 1, "Should have made multiple reads for streaming") + for i := 0; i < expectedSize; i++ { + expected := byte('A' + (i % 26)) + require.Equal(t, expected, totalRead[i], "Data pattern mismatch at position %d", i) + } + }) + + t.Run("stream object identity when streaming enabled", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/stream", func(c fiber.Ctx) error { + // Use chunked encoding to force streaming + c.Set("Transfer-Encoding", "chunked") + data := make([]byte, 1024*8) // 8KB + for i := range data { + data[i] = byte('S') + } + return c.SendStream(bytes.NewReader(data)) + }) + }) + defer server.stop() + + client := New().SetDial(server.dial()).SetStreamResponseBody(true) + resp, err := client.Get("http://example.com/stream") + require.NoError(t, err) + defer resp.Close() + rawStream := resp.RawResponse.BodyStream() + if rawStream != nil { + require.True(t, resp.IsStreaming()) + bodyStream := resp.BodyStream() + require.NotNil(t, bodyStream) + require.Same(t, rawStream, bodyStream, "BodyStream() should return the exact same stream object when RawResponse.BodyStream() is not nil") + } else { + require.False(t, resp.IsStreaming()) + bodyStream := resp.BodyStream() + require.NotNil(t, bodyStream) + _, ok := bodyStream.(*bytes.Reader) + require.True(t, ok, "When RawResponse.BodyStream() is nil, BodyStream() should return a *bytes.Reader") + } + bodyStream := resp.BodyStream() + data, err := io.ReadAll(bodyStream) + require.NoError(t, err) + require.Len(t, data, 1024*8) + for _, b := range data { + require.Equal(t, byte('S'), b) + } + }) +} + +func Test_Response_BodyStream_Fallback(t *testing.T) { + t.Parallel() + t.Run("non-streaming response fallback to bytes.Reader", func(t *testing.T) { + t.Parallel() + server := startTestServer(t, func(app *fiber.App) { + app.Get("/regular", func(c fiber.Ctx) error { + return c.SendString("regular response body") + }) + }) + defer server.stop() + client := New().SetDial(server.dial()) + resp, err := client.Get("http://example.com/regular") + require.NoError(t, err) + defer resp.Close() + require.False(t, resp.IsStreaming()) + bodyStream := resp.BodyStream() + require.NotNil(t, bodyStream) + data, err := io.ReadAll(bodyStream) + require.NoError(t, err) + require.Equal(t, "regular response body", string(data)) + bodyBytes := resp.Body() + require.Equal(t, "regular response body", string(bodyBytes)) + }) + + t.Run("empty response body stream", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/empty", func(c fiber.Ctx) error { + return c.SendStream(bytes.NewReader([]byte{})) + }) + }) + defer server.stop() + client := New().SetDial(server.dial()) + resp, err := client.Get("http://example.com/empty") + require.NoError(t, err) + defer resp.Close() + require.False(t, resp.IsStreaming()) + + bodyStream := resp.BodyStream() + require.NotNil(t, bodyStream) + data, err := io.ReadAll(bodyStream) + require.NoError(t, err) + require.Empty(t, string(data)) + }) + + t.Run("large non-streaming response", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/large", func(c fiber.Ctx) error { + data := make([]byte, 50*1024) // 50KB + for i := range data { + data[i] = byte('X') + } + return c.Send(data) + }) + }) + defer server.stop() + client := New().SetDial(server.dial()) + resp, err := client.Get("http://example.com/large") + require.NoError(t, err) + defer resp.Close() + require.False(t, resp.IsStreaming()) + bodyStream := resp.BodyStream() + require.NotNil(t, bodyStream) + buffer := make([]byte, 1024) + var totalRead []byte + for { + n, err := bodyStream.Read(buffer) + if n > 0 { + totalRead = append(totalRead, buffer[:n]...) + } + if err == io.EOF { + break + } + require.NoError(t, err) + } + + require.Len(t, totalRead, 50*1024) + for i, b := range totalRead { + require.Equal(t, byte('X'), b, "Byte mismatch at position %d", i) + } + }) +} + +func Test_Response_IsStreaming(t *testing.T) { + t.Parallel() + + t.Run("streaming with large response", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/large-stream", func(c fiber.Ctx) error { + data := make([]byte, 64*1024) // 64KB + for i := range data { + data[i] = byte('S') + } + return c.SendStream(bytes.NewReader(data)) + }) + }) + defer server.stop() + client := New().SetDial(server.dial()).SetStreamResponseBody(true) + resp, err := client.Get("http://example.com/large-stream") + require.NoError(t, err) + defer resp.Close() + bodyStream := resp.BodyStream() + require.NotNil(t, bodyStream) + data, err := io.ReadAll(bodyStream) + require.NoError(t, err) + require.Len(t, data, 64*1024) + }) + + t.Run("streaming disabled", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/regular", func(c fiber.Ctx) error { + return c.SendString("regular content") + }) + }) + defer server.stop() + client := New().SetDial(server.dial()) + resp, err := client.Get("http://example.com/regular") + require.NoError(t, err) + defer resp.Close() + require.False(t, resp.IsStreaming()) + }) + + t.Run("bodystream always works regardless of streaming state", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/test", func(c fiber.Ctx) error { + return c.SendString("test content") + }) + }) + defer server.stop() + client1 := New().SetDial(server.dial()).SetStreamResponseBody(true) + resp1, err := client1.Get("http://example.com/test") + require.NoError(t, err) + defer resp1.Close() + bodyStream1 := resp1.BodyStream() + require.NotNil(t, bodyStream1) + data1, err := io.ReadAll(bodyStream1) + require.NoError(t, err) + require.Equal(t, "test content", string(data1)) + client2 := New().SetDial(server.dial()).SetStreamResponseBody(false) + resp2, err := client2.Get("http://example.com/test") + require.NoError(t, err) + defer resp2.Close() + require.False(t, resp2.IsStreaming()) + bodyStream2 := resp2.BodyStream() + require.NotNil(t, bodyStream2) + data2, err := io.ReadAll(bodyStream2) + require.NoError(t, err) + require.Equal(t, "test content", string(data2)) + require.Equal(t, string(data1), string(data2)) + }) +} + +func Test_Response_Save_Streaming_FilePath(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/stream", func(c fiber.Ctx) error { + c.Set("Transfer-Encoding", "chunked") + return c.SendStreamWriter(func(w *bufio.Writer) { + data := []byte("streaming file content") + if _, err := w.Write(data); err != nil { + return + } + if err := w.Flush(); err != nil { + return + } + }) + }) + }) + defer server.stop() + + client := New().SetDial(server.dial()).SetStreamResponseBody(true) + + resp, err := client.Get("http://example.com/stream") + require.NoError(t, err) + defer resp.Close() + + testFile := "./test/stream_tmp.json" + err = resp.Save(testFile) + require.NoError(t, err) + + defer func() { + _, statErr := os.Stat(testFile) + require.NoError(t, statErr) + + statErr = os.RemoveAll("./test") + require.NoError(t, statErr) + }() + + file, err := os.Open(testFile) + require.NoError(t, err) + defer func(file *os.File) { + closeErr := file.Close() + require.NoError(t, closeErr) + }(file) + + data, err := io.ReadAll(file) + require.NoError(t, err) + require.Equal(t, "streaming file content", string(data)) +} + +func Test_Response_Save_FileError(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/test", func(c fiber.Ctx) error { + return c.SendString("test content") + }) + }) + defer server.stop() + + client := New().SetDial(server.dial()) + + resp, err := client.Get("http://example.com/test") + require.NoError(t, err) + defer resp.Close() + + // Try to save to an invalid path (directory doesn't exist and can't be created) + err = resp.Save("/root/invalid_path/file.txt") + require.Error(t, err) + require.Contains(t, err.Error(), "failed to") +} + +func Test_Response_AcquireRelease(t *testing.T) { + t.Parallel() + + resp := AcquireResponse() + require.NotNil(t, resp) + + // Set some values + resp.RawResponse.SetStatusCode(200) + resp.RawResponse.SetBodyString("test body") + require.Equal(t, 200, resp.RawResponse.StatusCode()) + require.Equal(t, "test body", string(resp.RawResponse.Body())) + + // Release should reset the response + ReleaseResponse(resp) + + // After release, the response should be reset + // Note: We can't test the exact same object after release as it's unsafe + // So we just acquire a new one to verify the pool works + resp2 := AcquireResponse() + require.NotNil(t, resp2) + ReleaseResponse(resp2) +} + +func Test_Response_Save_FileWriteError(t *testing.T) { + t.Parallel() + + t.Run("streaming response", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/stream", func(c fiber.Ctx) error { + c.Set("Transfer-Encoding", "chunked") + return c.SendStreamWriter(func(w *bufio.Writer) { + data := []byte(strings.Repeat("streaming content ", 1000)) + if _, err := w.Write(data); err != nil { + return + } + if err := w.Flush(); err != nil { + return + } + }) + }) + }) + defer server.stop() + + client := New().SetDial(server.dial()).SetStreamResponseBody(true) + + resp, err := client.Get("http://example.com/stream") + require.NoError(t, err) + defer resp.Close() + + err = resp.Save("/dev/full") + if err != nil { + require.Contains(t, err.Error(), "failed to write response body to file") + } + }) } diff --git a/client/transport.go b/client/transport.go index a77ac0d95d2..00dd3a25ecd 100644 --- a/client/transport.go +++ b/client/transport.go @@ -27,6 +27,8 @@ type httpClientTransport interface { SetTLSConfig(config *tls.Config) SetDial(dial fasthttp.DialFunc) Client() any + StreamResponseBody() bool + SetStreamResponseBody(enable bool) } // standardClientTransport adapts fasthttp.Client to the httpClientTransport @@ -75,6 +77,14 @@ func (s *standardClientTransport) Client() any { return s.client } +func (s *standardClientTransport) StreamResponseBody() bool { + return s.client.StreamResponseBody +} + +func (s *standardClientTransport) SetStreamResponseBody(enable bool) { + s.client.StreamResponseBody = enable +} + // hostClientTransport adapts fasthttp.HostClient to the httpClientTransport // interface used by Fiber's client helpers. type hostClientTransport struct { @@ -121,6 +131,14 @@ func (h *hostClientTransport) Client() any { return h.client } +func (h *hostClientTransport) StreamResponseBody() bool { + return h.client.StreamResponseBody +} + +func (h *hostClientTransport) SetStreamResponseBody(enable bool) { + h.client.StreamResponseBody = enable +} + // lbClientTransport adapts fasthttp.LBClient to the httpClientTransport // interface used by Fiber's client helpers. type lbClientTransport struct { @@ -179,6 +197,29 @@ func (l *lbClientTransport) Client() any { return l.client } +func (l *lbClientTransport) StreamResponseBody() bool { + if len(l.client.Clients) == 0 { + return false + } + // Return the StreamResponseBody setting from the first HostClient + var streamEnabled bool + for _, c := range l.client.Clients { + if walkBalancingClientWithBreak(c, func(hc *fasthttp.HostClient) bool { + streamEnabled = hc.StreamResponseBody + return true + }) { + break + } + } + return streamEnabled +} + +func (l *lbClientTransport) SetStreamResponseBody(enable bool) { + forEachHostClient(l.client, func(hc *fasthttp.HostClient) { + hc.StreamResponseBody = enable + }) +} + // forEachHostClient applies fn to every host client reachable from the provided // load balancer by recursively following nested balancers and wrapper types. func forEachHostClient(lb *fasthttp.LBClient, fn func(*fasthttp.HostClient)) { diff --git a/client/transport_test.go b/client/transport_test.go index fe380f1fdfd..231ba7dcfe2 100644 --- a/client/transport_test.go +++ b/client/transport_test.go @@ -346,3 +346,203 @@ func TestDoRedirectsWithClientDefaultLimit(t *testing.T) { require.ErrorIs(t, err, fasthttp.ErrTooManyRedirects) require.Equal(t, defaultRedirectLimit+1, client.CallCount()) } + +func Test_StandardClientTransport_StreamResponseBody(t *testing.T) { + t.Parallel() + + t.Run("default value", func(t *testing.T) { + t.Parallel() + transport := newStandardClientTransport(&fasthttp.Client{}) + require.False(t, transport.StreamResponseBody()) + }) + + t.Run("enable streaming", func(t *testing.T) { + t.Parallel() + client := &fasthttp.Client{} + transport := newStandardClientTransport(client) + transport.SetStreamResponseBody(true) + require.True(t, transport.StreamResponseBody()) + require.True(t, client.StreamResponseBody) + }) + + t.Run("disable streaming", func(t *testing.T) { + t.Parallel() + client := &fasthttp.Client{} + transport := newStandardClientTransport(client) + transport.SetStreamResponseBody(true) + require.True(t, transport.StreamResponseBody()) + transport.SetStreamResponseBody(false) + require.False(t, transport.StreamResponseBody()) + require.False(t, client.StreamResponseBody) + }) +} + +func Test_HostClientTransport_StreamResponseBody(t *testing.T) { + t.Parallel() + + t.Run("default value", func(t *testing.T) { + t.Parallel() + hostClient := &fasthttp.HostClient{} + transport := newHostClientTransport(hostClient) + require.False(t, transport.StreamResponseBody()) + }) + + t.Run("enable streaming", func(t *testing.T) { + t.Parallel() + hostClient := &fasthttp.HostClient{} + transport := newHostClientTransport(hostClient) + transport.SetStreamResponseBody(true) + require.True(t, transport.StreamResponseBody()) + require.True(t, hostClient.StreamResponseBody) + }) + + t.Run("disable streaming", func(t *testing.T) { + t.Parallel() + hostClient := &fasthttp.HostClient{} + transport := newHostClientTransport(hostClient) + transport.SetStreamResponseBody(true) + require.True(t, transport.StreamResponseBody()) + transport.SetStreamResponseBody(false) + require.False(t, transport.StreamResponseBody()) + require.False(t, hostClient.StreamResponseBody) + }) +} + +func Test_LBClientTransport_StreamResponseBody(t *testing.T) { + t.Parallel() + + t.Run("empty clients", func(t *testing.T) { + t.Parallel() + lbClient := &fasthttp.LBClient{ + Clients: []fasthttp.BalancingClient{}, + } + transport := newLBClientTransport(lbClient) + require.False(t, transport.StreamResponseBody()) + }) + + t.Run("single host client", func(t *testing.T) { + t.Parallel() + hostClient := &fasthttp.HostClient{Addr: "example.com:80"} + lbClient := &fasthttp.LBClient{ + Clients: []fasthttp.BalancingClient{hostClient}, + } + transport := newLBClientTransport(lbClient) + + // Test default + require.False(t, transport.StreamResponseBody()) + + // Enable streaming + transport.SetStreamResponseBody(true) + require.True(t, transport.StreamResponseBody()) + require.True(t, hostClient.StreamResponseBody) + + // Disable streaming + transport.SetStreamResponseBody(false) + require.False(t, transport.StreamResponseBody()) + require.False(t, hostClient.StreamResponseBody) + }) + + t.Run("multiple host clients", func(t *testing.T) { + t.Parallel() + hostClient1 := &fasthttp.HostClient{Addr: "example1.com:80"} + hostClient2 := &fasthttp.HostClient{Addr: "example2.com:80"} + lbClient := &fasthttp.LBClient{ + Clients: []fasthttp.BalancingClient{hostClient1, hostClient2}, + } + transport := newLBClientTransport(lbClient) + + // Enable streaming on all clients + transport.SetStreamResponseBody(true) + require.True(t, transport.StreamResponseBody()) + require.True(t, hostClient1.StreamResponseBody) + require.True(t, hostClient2.StreamResponseBody) + + // Disable streaming on all clients + transport.SetStreamResponseBody(false) + require.False(t, transport.StreamResponseBody()) + require.False(t, hostClient1.StreamResponseBody) + require.False(t, hostClient2.StreamResponseBody) + }) + + t.Run("nested lb client", func(t *testing.T) { + t.Parallel() + hostClient1 := &fasthttp.HostClient{Addr: "example1.com:80"} + hostClient2 := &fasthttp.HostClient{Addr: "example2.com:80"} + nestedLB := &fasthttp.LBClient{ + Clients: []fasthttp.BalancingClient{hostClient1, hostClient2}, + } + lbClient := &fasthttp.LBClient{ + Clients: []fasthttp.BalancingClient{&lbBalancingClient{client: nestedLB}}, + } + transport := newLBClientTransport(lbClient) + + // Enable streaming on nested clients + transport.SetStreamResponseBody(true) + require.True(t, hostClient1.StreamResponseBody) + require.True(t, hostClient2.StreamResponseBody) + + // Disable streaming on nested clients + transport.SetStreamResponseBody(false) + require.False(t, hostClient1.StreamResponseBody) + require.False(t, hostClient2.StreamResponseBody) + }) + + t.Run("mixed clients with stub", func(t *testing.T) { + t.Parallel() + hostClient := &fasthttp.HostClient{Addr: "example.com:80"} + lbClient := &fasthttp.LBClient{ + Clients: []fasthttp.BalancingClient{hostClient, stubBalancingClient{}}, + } + transport := newLBClientTransport(lbClient) + + // Enable streaming + transport.SetStreamResponseBody(true) + require.True(t, transport.StreamResponseBody()) + require.True(t, hostClient.StreamResponseBody) + + // Disable streaming + transport.SetStreamResponseBody(false) + require.False(t, transport.StreamResponseBody()) + require.False(t, hostClient.StreamResponseBody) + }) +} + +func Test_httpClientTransport_Interface(t *testing.T) { + t.Parallel() + + transports := []struct { + transport httpClientTransport + name string + }{ + { + name: "standardClientTransport", + transport: newStandardClientTransport(&fasthttp.Client{}), + }, + { + name: "hostClientTransport", + transport: newHostClientTransport(&fasthttp.HostClient{}), + }, + { + name: "lbClientTransport", + transport: newLBClientTransport(&fasthttp.LBClient{ + Clients: []fasthttp.BalancingClient{ + &fasthttp.HostClient{Addr: "example.com:80"}, + }, + }), + }, + } + + for _, tt := range transports { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + transport := tt.transport + require.NotNil(t, transport.Client()) + initialStream := transport.StreamResponseBody() + transport.SetStreamResponseBody(!initialStream) + require.Equal(t, !initialStream, transport.StreamResponseBody()) + transport.SetStreamResponseBody(initialStream) + require.Equal(t, initialStream, transport.StreamResponseBody()) + }) + } +} diff --git a/ctx_test.go b/ctx_test.go index c39decc3ccc..73d2c5cf6f6 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -3775,7 +3775,7 @@ func Test_Ctx_SaveFile(t *testing.T) { req.Header.Set("Content-Type", writer.FormDataContentType()) req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes()))) - resp, err := app.Test(req) + resp, err := app.Test(req, TestConfig{Timeout: 5 * time.Second}) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") } @@ -3817,7 +3817,7 @@ func Test_Ctx_SaveFileToStorage(t *testing.T) { req.Header.Set("Content-Type", writer.FormDataContentType()) req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes()))) - resp, err := app.Test(req) + resp, err := app.Test(req, TestConfig{Timeout: 5 * time.Second}) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") } diff --git a/docs/client/rest.md b/docs/client/rest.md index 1ad85910db6..f66b52b001d 100644 --- a/docs/client/rest.md +++ b/docs/client/rest.md @@ -406,6 +406,22 @@ Sets a proxy URL for the client. All subsequent requests will use this proxy. func (c *Client) SetProxyURL(proxyURL string) error ``` +## StreamResponseBody + +Returns whether response body streaming is enabled. When enabled, the response body is not fully loaded into memory and can be read as a stream using `BodyStream()`. This is useful for handling large responses or server-sent events. + +```go title="Signature" +func (c *Client) StreamResponseBody() bool +``` + +## SetStreamResponseBody + +Sets whether the response body should be streamed directly to the caller instead of being fully buffered in memory. This is useful for downloading large files or handling streaming responses. + +```go title="Signature" +func (c *Client) SetStreamResponseBody(stream bool) *Client +``` + ## RetryConfig Returns the retry configuration of the client.