Skip to content

Commit

Permalink
add tests and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
piksel committed Jul 27, 2021
1 parent 7d6c118 commit d4f8791
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 63 deletions.
80 changes: 21 additions & 59 deletions pkg/common/webclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,13 @@ package webclient

import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
)

// DefaultJsonClient is the singleton instance of WebClient using http.DefaultClient
var DefaultJsonClient = NewJSONClient()

// GetJson fetches url using GET and unmarshals into the passed response using DefaultJsonClient
func GetJson(url string, response interface{}) error {
return DefaultJsonClient.Get(url, response)
}

// PostJson sends request as JSON and unmarshals the response JSON into the supplied struct using DefaultJsonClient
func PostJson(url string, request interface{}, response interface{}) error {
return DefaultJsonClient.Post(url, request, response)
}

// WebClient is a JSON wrapper around http.WebClient
// client is a wrapper around http.Client with common notification service functionality
type client struct {
headers http.Header
indent string
Expand All @@ -30,11 +17,6 @@ type client struct {
write WriterFunc
}

// SetTransport overrides the http.RoundTripper for the web client, mainly used for testing
func (c *client) SetTransport(transport http.RoundTripper) {
c.HttpClient.Transport = transport
}

// SetParser overrides the parser for the incoming response content
func (c *client) SetParser(parse ParserFunc) {
c.parse = parse
Expand All @@ -45,74 +27,54 @@ func (c *client) SetWriter(write WriterFunc) {
c.write = write
}

// NewJSONClient returns a WebClient using the default http.Client and JSON serialization
func NewJSONClient() WebClient {
var c client
c = client{
headers: http.Header{
"Content-Type": []string{JsonContentType},
},
parse: json.Unmarshal,
write: func(v interface{}) ([]byte, error) {
return json.MarshalIndent(v, "", c.indent)
},
}
return &c
}

// Headers return the default headers for requests
func (c *client) Headers() http.Header {
return c.headers
}

// Get fetches url using GET and unmarshals into the passed response
func (c *client) Get(url string, response interface{}) error {
res, err := c.HttpClient.Get(url)
if err != nil {
return err
}

return c.parseResponse(res, response)
return c.request(url, response, nil)
}

// Post sends a serialized representation of request and deserializes the result into response
func (c *client) Post(url string, request interface{}, response interface{}) error {
var err error
var body []byte

body, err = c.write(request)
body, err := c.write(request)
if err != nil {
return fmt.Errorf("error creating payload: %v", err)
}

req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
return c.request(url, response, bytes.NewReader(body))
}

// ErrorResponse tries to deserialize any response body into the supplied struct, returning whether successful or not
func (c *client) ErrorResponse(err error, response interface{}) bool {
jerr, isWebError := err.(ClientError)
if !isWebError {
return false
}

return c.parse([]byte(jerr.Body), response) == nil
}

func (c *client) request(url string, response interface{}, payload io.Reader) error {
req, err := http.NewRequest(http.MethodPost, url, payload)
if err != nil {
return fmt.Errorf("error creating request: %v", err)
return err
}

for key, val := range c.headers {
req.Header.Set(key, val[0])
}

var res *http.Response
res, err = c.HttpClient.Do(req)
res, err := c.HttpClient.Do(req)
if err != nil {
return fmt.Errorf("error sending payload: %v", err)
}

return c.parseResponse(res, response)
}

// ErrorResponse tries to deserialize any response body into the supplied struct, returning whether successful or not
func (c *client) ErrorResponse(err error, response interface{}) bool {
jerr, isWebError := err.(ClientError)
if !isWebError {
return false
}

return c.parse([]byte(jerr.Body), response) == nil
}

func (c *client) parseResponse(res *http.Response, response interface{}) error {
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
Expand Down
42 changes: 41 additions & 1 deletion pkg/common/webclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
. "github.com/onsi/gomega"
)

var _ = Describe("JSONClient", func() {
var _ = Describe("WebClient", func() {
var server *ghttp.Server

BeforeEach(func() {
Expand Down Expand Up @@ -48,6 +48,46 @@ var _ = Describe("JSONClient", func() {
Expect(res.Status).To(Equal("OK"))
})

It("should update the parser and writer", func() {
client := webclient.NewJSONClient()
client.SetParser(func(raw []byte, v interface{}) error {
return errors.New(`mock parser`)
})
server.AppendHandlers(ghttp.RespondWithJSONEncoded(http.StatusOK, mockResponse{Status: "OK"}))
err := client.Get(server.URL(), nil)
Expect(err).To(MatchError(`mock parser`))

client.SetWriter(func(v interface{}) ([]byte, error) {
return nil, errors.New(`mock writer`)
})
err = client.Post(server.URL(), nil, nil)
Expect(err).To(MatchError(`error creating payload: mock writer`))
})

It("should unwrap serialized error responses", func() {
client := webclient.NewJSONClient()
err := webclient.ClientError{Body: `{"Status": "BadStuff"}`}
res := &mockResponse{}
Expect(client.ErrorResponse(err, res)).To(BeTrue())
Expect(res.Status).To(Equal(`BadStuff`))
})

It("should send any additional headers that has been added", func() {
server.AppendHandlers(
ghttp.CombineHandlers(
ghttp.VerifyHeaderKV(`Authentication`, `you don't need to see my identification`),
ghttp.RespondWithJSONEncoded(http.StatusOK, mockResponse{Status: "OK"}),
),
)
client := webclient.NewJSONClient()
client.Headers().Set(`Authentication`, `you don't need to see my identification`)
res := &mockResponse{}
err := client.Get(server.URL(), &res)
Expect(server.ReceivedRequests()).Should(HaveLen(1))
Expect(err).ToNot(HaveOccurred())
Expect(res.Status).To(Equal("OK"))
})

Describe("POST", func() {
It("should de-/serialize request and response", func() {

Expand Down
37 changes: 37 additions & 0 deletions pkg/common/webclient/json.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package webclient

import (
"encoding/json"
"net/http"
)

// JsonContentType is the default mime type for JSON
const JsonContentType = "application/json"

// DefaultJsonClient is the singleton instance of WebClient using http.DefaultClient
var DefaultJsonClient = NewJSONClient()

// GetJson fetches url using GET and unmarshals into the passed response using DefaultJsonClient
func GetJson(url string, response interface{}) error {
return DefaultJsonClient.Get(url, response)
}

// PostJson sends request as JSON and unmarshals the response JSON into the supplied struct using DefaultJsonClient
func PostJson(url string, request interface{}, response interface{}) error {
return DefaultJsonClient.Post(url, request, response)
}

// NewJSONClient returns a WebClient using the default http.Client and JSON serialization
func NewJSONClient() WebClient {
var c client
c = client{
headers: http.Header{
"Content-Type": []string{JsonContentType},
},
parse: json.Unmarshal,
write: func(v interface{}) ([]byte, error) {
return json.MarshalIndent(v, "", c.indent)
},
}
return &c
}
3 changes: 0 additions & 3 deletions pkg/common/webclient/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ type WriterFunc func(v interface{}) ([]byte, error)
var _ types.TLSClient = &ClientService{}
var _ types.HTTPService = &ClientService{}

// JsonContentType is the default mime type for JSON
const JsonContentType = "application/json"

// ClientService is a Composable that adds a generic web request client to the service
type ClientService struct {
client *client
Expand Down
47 changes: 47 additions & 0 deletions pkg/common/webclient/service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package webclient_test

import (
"github.com/containrrr/shoutrrr/pkg/common/webclient"
"net/http"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)

var _ = Describe("ClientService", func() {

When("getting the web client from an empty service", func() {
It("should return an initialized web client", func() {
service := &webclient.ClientService{}
Expect(service.WebClient()).ToNot(BeNil())
})
})

When("getting the http client from an empty service", func() {
It("should return an initialized http client", func() {
service := &webclient.ClientService{}
Expect(service.HTTPClient()).ToNot(BeNil())
})
})

When("no certs have been added", func() {
It("should use nil as the certificate pool", func() {
service := &webclient.ClientService{}
tp := service.HTTPClient().Transport.(*http.Transport)
Expect(tp.TLSClientConfig.RootCAs).To(BeNil())
})
})

When("a custom cert have been added", func() {
It("should use a custom certificate pool", func() {
service := &webclient.ClientService{}

// Adding an empty cert should fail
addedOk := service.AddTrustedRootCertificate([]byte{})
Expect(addedOk).To(BeFalse())

tp := service.HTTPClient().Transport.(*http.Transport)
Expect(tp.TLSClientConfig.RootCAs).ToNot(BeNil())
})
})
})

0 comments on commit d4f8791

Please sign in to comment.