Skip to content

Commit

Permalink
Merge pull request #24 from go-chai/fix-custom-error-types
Browse files Browse the repository at this point in the history
fix: handling of custom error types
  • Loading branch information
e-nikolov authored Feb 20, 2022
2 parents f0b9ef4 + 14b0330 commit 28a8ba2
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 31 deletions.
57 changes: 37 additions & 20 deletions chai/chai.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package chai

import (
"encoding/json"
"net/http"

"encoding/json"
)

type Methoder interface {
Expand All @@ -22,17 +23,6 @@ type Handlerer interface {
Handler() any
}

type chaiErr struct {
Message string `json:"error"`
ErrorDebug string `json:"error_debug,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
StatusCode int `json:"status_code,omitempty"`
}

func (e chaiErr) Error() string {
return e.Message
}

func write(w http.ResponseWriter, code int, v any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
Expand All @@ -51,23 +41,50 @@ func writeBytes(w http.ResponseWriter, code int, bytes []byte) {

type ErrType = error

type Err error

type ErrWrap struct {
Err error
StatusCode int
Error string
}

// TODO figure out how to do this without multiple json.Marshal/Unmarshal calls
func (ew *ErrWrap) MarshalJSON() ([]byte, error) {
m := map[string]interface{}{
"error": ew.Error,
"status_code": ew.StatusCode,
}

b, err := json.Marshal(ew.Err)
if err != nil {
return nil, err
}

err = json.Unmarshal(b, &m)
if err != nil {
return nil, err
}

return json.Marshal(m)
}

type ErrorWriter interface {
WriteError(w http.ResponseWriter, code int, e ErrType)
}

type defaultErrorWriter struct{}

func (defaultErrorWriter) WriteError(w http.ResponseWriter, code int, e ErrType) {
b, err := json.Marshal(e)
if err != nil {
panic(err)
ew := &ErrWrap{
Err: e,
StatusCode: code,
Error: e.Error(),
}

if string(b) == "{}" {
b, err = json.Marshal(&chaiErr{Message: e.Error(), StatusCode: code})
if err != nil {
panic(err)
}
b, err := json.Marshal(ew)
if err != nil {
panic(err)
}

writeBytes(w, code, b)
Expand Down
155 changes: 152 additions & 3 deletions chai/chai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/go-chai/chai/chai"
"github.com/go-chai/chai/internal/tests"
"github.com/go-chai/chai/internal/tests/xrequire"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -40,7 +42,7 @@ func newReq() io.Reader {
return buf
}

func TestReqResHandler(t *testing.T) {
func TestHandlers(t *testing.T) {
tests := []struct {
name string
makeHandler func(t *testing.T) http.Handler
Expand All @@ -64,6 +66,78 @@ func TestReqResHandler(t *testing.T) {
},
response: `{"error":"zz", "status_code":500}`,
},
{
name: "req res handler with custom struct error type with a pointer receiver with no error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewReqResHandler(func(req *tests.TestRequest, w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, *tests.TestErrorPtr) {
return newRes(), http.StatusOK, nil
})
},
response: `{"foo":"f","bar":"b","test_inner_response":{"foo_foo":123,"bar_bar":12}}`,
},
{
name: "req res handler with custom struct error type with a pointer receiver with error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewReqResHandler(func(req *tests.TestRequest, w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, *tests.TestErrorPtr) {
return nil, http.StatusInternalServerError, &tests.TestErrorPtr{Message: "zz"}
})
},
response: `{"error":"zz", "message":"zz", "status_code":500}`,
},
{
name: "req res handler with custom struct error type with no error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewReqResHandler(func(req *tests.TestRequest, w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, tests.TestError) {
return newRes(), http.StatusOK, tests.TestError{}
})
},
response: `{"foo":"f","bar":"b","test_inner_response":{"foo_foo":123,"bar_bar":12}}`,
},
{
name: "req res handler with custom struct error type with error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewReqResHandler(func(req *tests.TestRequest, w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, tests.TestError) {
return nil, http.StatusInternalServerError, tests.TestError{Message: "zz"}
})
},
response: `{"error":"zz", "message":"zz", "status_code":500}`,
},
{
name: "req res handler with custom map error type with no error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewReqResHandler(func(req *tests.TestRequest, w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, tests.TestErrorMap) {
return newRes(), http.StatusOK, nil
})
},
response: `{"foo":"f","bar":"b","test_inner_response":{"foo_foo":123,"bar_bar":12}}`,
},
{
name: "req res handler with custom map error type with error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewReqResHandler(func(req *tests.TestRequest, w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, tests.TestErrorMap) {
return nil, http.StatusInternalServerError, tests.TestErrorMap{"message": "zz"}
})
},
response: `{"error":"test error map", "message":"zz", "status_code":500}`,
},
{
name: "req res handler with custom map error type with a pointer receiver with no error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewReqResHandler(func(req *tests.TestRequest, w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, *tests.TestErrorMapPtr) {
return newRes(), http.StatusOK, nil
})
},
response: `{"foo":"f","bar":"b","test_inner_response":{"foo_foo":123,"bar_bar":12}}`,
},
{
name: "req res handler with custom map error type with a pointer receiver with error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewReqResHandler(func(req *tests.TestRequest, w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, *tests.TestErrorMapPtr) {
return nil, http.StatusInternalServerError, &tests.TestErrorMapPtr{"message": "zz"}
})
},
response: `{"error":"test error map ptr", "message":"zz", "status_code":500}`,
},
{
name: "res handler",
makeHandler: func(t *testing.T) http.Handler {
Expand All @@ -74,14 +148,87 @@ func TestReqResHandler(t *testing.T) {
response: `{"foo":"f","bar":"b","test_inner_response":{"foo_foo":123,"bar_bar":12}}`,
},
{
name: "req res handler with error",
name: "res handler with error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewResHandler(func(w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, error) {
return nil, http.StatusInternalServerError, errors.New("zz")
})
},
response: `{"error":"zz", "status_code":500}`,
},
{
name: "res handler with custom struct error type with a pointer receiver with no error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewResHandler(func(w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, *tests.TestErrorPtr) {
return newRes(), http.StatusOK, nil
})
},
response: `{"foo":"f","bar":"b","test_inner_response":{"foo_foo":123,"bar_bar":12}}`,
},
{
name: "res handler with custom struct error type with a pointer receiver with error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewResHandler(func(w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, *tests.TestErrorPtr) {
return nil, http.StatusInternalServerError, &tests.TestErrorPtr{Message: "zz"}
})
},
response: `{"error":"zz", "message":"zz", "status_code":500}`,
},

{
name: "res handler with custom struct error type with no error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewResHandler(func(w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, tests.TestError) {
return newRes(), http.StatusOK, tests.TestError{}
})
},
response: `{"foo":"f","bar":"b","test_inner_response":{"foo_foo":123,"bar_bar":12}}`,
},
{
name: "res handler with custom struct error type with error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewResHandler(func(w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, tests.TestError) {
return nil, http.StatusInternalServerError, tests.TestError{Message: "zz"}
})
},
response: `{"error":"zz", "message":"zz", "status_code":500}`,
},
{
name: "res handler with custom map error type with no error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewResHandler(func(w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, tests.TestErrorMap) {
return newRes(), http.StatusOK, nil
})
},
response: `{"foo":"f","bar":"b","test_inner_response":{"foo_foo":123,"bar_bar":12}}`,
},
{
name: "res handler with custom map error type with error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewResHandler(func(w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, tests.TestErrorMap) {
return nil, http.StatusInternalServerError, tests.TestErrorMap{"message": "zz"}
})
},
response: `{"error":"test error map", "message":"zz", "status_code":500}`,
},
{
name: "res handler with custom map error type with a pointer receiver with no error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewResHandler(func(w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, *tests.TestErrorMapPtr) {
return newRes(), http.StatusOK, nil
})
},
response: `{"foo":"f","bar":"b","test_inner_response":{"foo_foo":123,"bar_bar":12}}`,
},
{
name: "res handler with custom map error type with a pointer receiver with error",
makeHandler: func(t *testing.T) http.Handler {
return chai.NewResHandler(func(w http.ResponseWriter, r *http.Request) (*tests.TestResponse, int, *tests.TestErrorMapPtr) {
return nil, http.StatusInternalServerError, &tests.TestErrorMapPtr{"message": "zz"}
})
},
response: `{"error":"test error map ptr", "message":"zz", "status_code":500}`,
},
}

for _, tt := range tests {
Expand All @@ -93,7 +240,9 @@ func TestReqResHandler(t *testing.T) {

h.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/", newReq()))

require.JSONEq(t, tt.response, w.Body.String())
fmt.Printf("%q\n", w.Body.String())

xrequire.JSONEq(t, tt.response, w.Body.String())
})
})
}
Expand Down
9 changes: 6 additions & 3 deletions chai/req_res_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package chai

import (
"encoding/json"
"errors"
"net/http"
"reflect"
)

type ReqResHandlerFunc[Req any, Res any, Err ErrType] func(Req, http.ResponseWriter, *http.Request) (Res, int, Err)
Expand All @@ -21,6 +21,10 @@ type ReqResHandler[Req any, Res any, Err ErrType] struct {
err *Err
}

func isErr[Err ErrType](err Err) bool {
return !reflect.ValueOf(&err).Elem().IsZero()
}

func (h *ReqResHandler[Req, Res, Err]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var req *Req

Expand All @@ -30,8 +34,7 @@ func (h *ReqResHandler[Req, Res, Err]) ServeHTTP(w http.ResponseWriter, r *http.
}

res, code, err := h.f(*req, w, r)

if !errors.Is(err, nil) {
if isErr(err) {
if code == 0 {
code = http.StatusInternalServerError
}
Expand Down
6 changes: 2 additions & 4 deletions chai/res_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@ package chai

import (
"net/http"

"errors"
)

type ResHandlerFunc[Res any, Err ErrType] func(http.ResponseWriter, *http.Request) (Res, int, Err)

func NewResHandler[Res any, Err ErrType](h ResHandlerFunc[Res, Err]) *ResHandler[Res, Err] {
return &ResHandler[Res, Err]{
f: h,
f: h,
}
}

Expand All @@ -22,7 +20,7 @@ type ResHandler[Res any, Err ErrType] struct {

func (h *ResHandler[Res, Err]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
res, code, err := h.f(w, r)
if !errors.Is(err, nil) {
if isErr(err) {
if code == 0 {
code = http.StatusInternalServerError
}
Expand Down
28 changes: 28 additions & 0 deletions internal/tests/testtypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,31 @@ type TestRequest struct {

TestInnerResponse TestInnerResponse `json:"test_inner_responseb"`
}

type TestError struct {
Message string `json:"message"`
}

func (e TestError) Error() string {
return e.Message
}

type TestErrorPtr struct {
Message string `json:"message"`
}

func (e *TestErrorPtr) Error() string {
return e.Message
}

type TestErrorMap map[string]string

func (e TestErrorMap) Error() string {
return "test error map"
}

type TestErrorMapPtr map[string]string

func (e *TestErrorMapPtr) Error() string {
return "test error map ptr"
}
3 changes: 2 additions & 1 deletion internal/tests/util.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package tests

import (
"encoding/json"
"io/ioutil"
"testing"

"encoding/json"

"github.com/stretchr/testify/require"
)

Expand Down
Loading

0 comments on commit 28a8ba2

Please sign in to comment.