Skip to content

Commit 49f7316

Browse files
Nikolay TkachenkoAndrey Buryndin
Nikolay Tkachenko
authored and
Andrey Buryndin
committed
Add API-Firewall source code; Add LICENSE
1 parent da0ac29 commit 49f7316

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

90 files changed

+17560
-60
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
# Dependency directories (remove the comment below to include it)
1515
# vendor/
1616
.DS_Store
17+
.idea/

LICENSE

+373
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package handlers
2+
3+
import (
4+
"os"
5+
6+
"github.com/sirupsen/logrus"
7+
"github.com/valyala/fasthttp"
8+
9+
"github.com/wallarm/api-firewall/internal/platform/proxy"
10+
"github.com/wallarm/api-firewall/internal/platform/web"
11+
)
12+
13+
type Health struct {
14+
Build string
15+
Logger *logrus.Logger
16+
Pool proxy.Pool
17+
}
18+
19+
// Readiness checks if the Fasthttp connection pool is ready to handle new requests.
20+
func (h Health) Readiness(ctx *fasthttp.RequestCtx) error {
21+
22+
status := "ok"
23+
statusCode := fasthttp.StatusOK
24+
25+
reverseProxy, err := h.Pool.Get()
26+
if err != nil {
27+
status = "not ready"
28+
statusCode = fasthttp.StatusInternalServerError
29+
}
30+
31+
if reverseProxy != nil {
32+
if err := h.Pool.Put(reverseProxy); err != nil {
33+
status = "not ready"
34+
statusCode = fasthttp.StatusInternalServerError
35+
}
36+
}
37+
38+
data := struct {
39+
Status string `json:"status"`
40+
}{
41+
Status: status,
42+
}
43+
44+
return web.Respond(ctx, data, statusCode)
45+
}
46+
47+
// Liveness returns simple status info if the service is alive. If the
48+
// app is deployed to a Kubernetes cluster, it will also return pod, node, and
49+
// namespace details via the Downward API. The Kubernetes environment variables
50+
// need to be set within your Pod/Deployment manifest.
51+
func (h Health) Liveness(ctx *fasthttp.RequestCtx) error {
52+
host, err := os.Hostname()
53+
if err != nil {
54+
host = "unavailable"
55+
}
56+
57+
data := struct {
58+
Status string `json:"status,omitempty"`
59+
Build string `json:"build,omitempty"`
60+
Host string `json:"host,omitempty"`
61+
Pod string `json:"pod,omitempty"`
62+
PodIP string `json:"podIP,omitempty"`
63+
Node string `json:"node,omitempty"`
64+
Namespace string `json:"namespace,omitempty"`
65+
}{
66+
Status: "up",
67+
Build: h.Build,
68+
Host: host,
69+
Pod: os.Getenv("KUBERNETES_PODNAME"),
70+
PodIP: os.Getenv("KUBERNETES_NAMESPACE_POD_IP"),
71+
Node: os.Getenv("KUBERNETES_NODENAME"),
72+
Namespace: os.Getenv("KUBERNETES_NAMESPACE"),
73+
}
74+
75+
statusCode := fasthttp.StatusOK
76+
return web.Respond(ctx, data, statusCode)
77+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
package handlers
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"errors"
7+
"fmt"
8+
"io/ioutil"
9+
"strings"
10+
"time"
11+
12+
"github.com/savsgio/gotils/strconv"
13+
"github.com/sirupsen/logrus"
14+
"github.com/valyala/fasthttp"
15+
"github.com/valyala/fastjson"
16+
17+
"github.com/wallarm/api-firewall/internal/config"
18+
"github.com/wallarm/api-firewall/internal/platform/oauth2"
19+
"github.com/wallarm/api-firewall/internal/platform/openapi3"
20+
"github.com/wallarm/api-firewall/internal/platform/openapi3filter"
21+
"github.com/wallarm/api-firewall/internal/platform/proxy"
22+
"github.com/wallarm/api-firewall/internal/platform/routers"
23+
"github.com/wallarm/api-firewall/internal/platform/web"
24+
)
25+
26+
type openapiWaf struct {
27+
route *routers.Route
28+
proxyPool proxy.Pool
29+
logger *logrus.Logger
30+
cfg *config.APIFWConfiguration
31+
pathParamLength int
32+
parserPool *fastjson.ParserPool
33+
oauthValidator oauth2.OAuth2
34+
}
35+
36+
// EXPERIMENTAL feature
37+
// returns APIFW-Validation-Status header value
38+
func getValidationHeader(ctx *fasthttp.RequestCtx, err error) *string {
39+
var reason = "unknown"
40+
41+
switch err.(type) {
42+
43+
case *openapi3filter.ResponseError:
44+
responseError, ok := err.(*openapi3filter.ResponseError)
45+
46+
if ok && responseError.Reason != "" {
47+
reason = responseError.Reason
48+
}
49+
50+
id := fmt.Sprintf("response-%d-%s", ctx.Response.StatusCode(), strings.Split(string(ctx.Response.Header.ContentType()), ";")[0])
51+
value := fmt.Sprintf("%s:%s:response", id, reason)
52+
return &value
53+
54+
case *openapi3filter.RequestError:
55+
56+
requestError, ok := err.(*openapi3filter.RequestError)
57+
if !ok {
58+
return nil
59+
}
60+
61+
if requestError.Reason != "" {
62+
reason = requestError.Reason
63+
}
64+
65+
if requestError.Parameter != nil {
66+
paramName := "request-parameter"
67+
68+
if requestError.Reason == "" {
69+
schemaError, ok := requestError.Err.(*openapi3.SchemaError)
70+
if ok && schemaError.Reason != "" {
71+
reason = schemaError.Reason
72+
}
73+
paramName = requestError.Parameter.Name
74+
}
75+
76+
value := fmt.Sprintf("request-parameter:%s:%s", reason, paramName)
77+
return &value
78+
}
79+
80+
if requestError.RequestBody != nil {
81+
id := fmt.Sprintf("request-body-%s", strings.Split(string(ctx.Request.Header.ContentType()), ";")[0])
82+
value := fmt.Sprintf("%s:%s:request-body", id, reason)
83+
return &value
84+
}
85+
case *openapi3filter.SecurityRequirementsError:
86+
87+
secSchemeName := ""
88+
for _, scheme := range err.(*openapi3filter.SecurityRequirementsError).SecurityRequirements {
89+
for key := range scheme {
90+
secSchemeName += key + ","
91+
}
92+
}
93+
94+
secErrors := ""
95+
for _, secError := range err.(*openapi3filter.SecurityRequirementsError).Errors {
96+
secErrors += secError.Error() + ","
97+
}
98+
99+
value := fmt.Sprintf("security-requirements-%s:%s:%s", strings.TrimSuffix(secSchemeName, ","), strings.TrimSuffix(secErrors, ","), strings.TrimSuffix(secSchemeName, ","))
100+
return &value
101+
}
102+
103+
return nil
104+
}
105+
106+
func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error {
107+
s.logger.Debugf("New Request: #%016X : %s -> %s %s (%s)",
108+
ctx.ID(),
109+
ctx.RemoteAddr(),
110+
ctx.Request.Header.Method(), ctx.Path(),
111+
time.Since(ctx.Time()),
112+
)
113+
114+
client, err := s.proxyPool.Get()
115+
if err != nil {
116+
s.logger.Errorf("#%016X : error while proxying request: %s", ctx.ID(), strings.Replace(err.Error(), "\n", " ", -1))
117+
return web.RespondError(ctx, fasthttp.StatusServiceUnavailable, nil)
118+
}
119+
defer s.proxyPool.Put(client)
120+
121+
// Handle request if Validation Disabled for request and response
122+
if s.route == nil || (s.cfg.RequestValidation == web.ValidationDisable && s.cfg.ResponseValidation == web.ValidationDisable) {
123+
124+
if err := client.Do(&ctx.Request, &ctx.Response); err != nil {
125+
s.logger.Errorf("#%016X : error while proxying request: %s", ctx.ID(), strings.Replace(err.Error(), "\n", " ", -1))
126+
switch err {
127+
case fasthttp.ErrDialTimeout:
128+
return web.RespondError(ctx, fasthttp.StatusGatewayTimeout, nil)
129+
case fasthttp.ErrNoFreeConns:
130+
return web.RespondError(ctx, fasthttp.StatusServiceUnavailable, nil)
131+
default:
132+
return web.RespondError(ctx, fasthttp.StatusBadGateway, nil)
133+
}
134+
}
135+
136+
// check shadow api if path or method are not found and validation mode is LOG_ONLY
137+
if s.route == nil && (s.cfg.RequestValidation == web.ValidationLog || s.cfg.ResponseValidation == web.ValidationLog) {
138+
web.ShadowAPIChecks(ctx, s.logger, &s.cfg.ShadowAPI)
139+
}
140+
141+
return nil
142+
}
143+
144+
var pathParams map[string]string
145+
146+
if s.pathParamLength > 0 {
147+
pathParams = make(map[string]string, s.pathParamLength)
148+
149+
ctx.VisitUserValues(func(key []byte, value interface{}) {
150+
keyStr := strconv.B2S(key)
151+
pathParams[keyStr] = value.(string)
152+
})
153+
}
154+
155+
// Validate request
156+
requestValidationInput := &openapi3filter.RequestValidationInput{
157+
RequestCtx: ctx,
158+
PathParams: pathParams,
159+
Route: s.route,
160+
ParserJson: s.parserPool,
161+
Options: &openapi3filter.Options{
162+
AuthenticationFunc: func(ctx context.Context, input *openapi3filter.AuthenticationInput) error {
163+
switch input.SecurityScheme.Type {
164+
case "http":
165+
switch input.SecurityScheme.Scheme {
166+
case "basic":
167+
bHeader := input.RequestValidationInput.RequestCtx.Request.Header.Peek("Authorization")
168+
if bHeader == nil || !strings.HasPrefix(strings.ToLower(strconv.B2S(bHeader)), "basic ") {
169+
return errors.New("missing basic authorization header")
170+
}
171+
case "bearer":
172+
bHeader := input.RequestValidationInput.RequestCtx.Request.Header.Peek("Authorization")
173+
if bHeader == nil || !strings.HasPrefix(strings.ToLower(strconv.B2S(bHeader)), "bearer ") {
174+
return errors.New("missing bearer authorization header")
175+
}
176+
}
177+
case "oauth2", "openIdConnect":
178+
if s.oauthValidator == nil {
179+
return errors.New("oauth2 validator not configured")
180+
}
181+
if err := s.oauthValidator.Validate(ctx, string(input.RequestValidationInput.RequestCtx.Request.Header.Peek("Authorization")), input.Scopes); err != nil {
182+
return fmt.Errorf("oauth2 error: %s", err)
183+
}
184+
185+
case "apiKey":
186+
switch input.SecurityScheme.In {
187+
case "header":
188+
if input.RequestValidationInput.RequestCtx.Request.Header.Peek(input.SecurityScheme.Name) == nil {
189+
return fmt.Errorf("missing %s header", input.SecurityScheme.Name)
190+
}
191+
case "query":
192+
if input.RequestValidationInput.RequestCtx.URI().QueryArgs().Peek(input.SecurityScheme.Name) == nil {
193+
return fmt.Errorf("missing %s query parameter", input.SecurityScheme.Name)
194+
}
195+
case "cookie":
196+
if input.RequestValidationInput.RequestCtx.Request.Header.Cookie(input.SecurityScheme.Name) == nil {
197+
return fmt.Errorf("missing %s cookie", input.SecurityScheme.Name)
198+
}
199+
}
200+
}
201+
return nil
202+
},
203+
},
204+
}
205+
206+
switch s.cfg.RequestValidation {
207+
case web.ValidationBlock:
208+
if err := openapi3filter.ValidateRequest(ctx, requestValidationInput); err != nil {
209+
s.logger.Errorf("#%016X : request validation error: %s", ctx.ID(), strings.Replace(err.Error(), "\n", " ", -1))
210+
if s.cfg.AddValidationStatusHeader {
211+
if vh := getValidationHeader(ctx, err); vh != nil {
212+
s.logger.Errorf("add header %s: %s", web.ValidationStatus, *vh)
213+
ctx.Request.Header.Add(web.ValidationStatus, *vh)
214+
return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, vh)
215+
}
216+
}
217+
return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, nil)
218+
}
219+
case web.ValidationLog:
220+
if err := openapi3filter.ValidateRequest(ctx, requestValidationInput); err != nil {
221+
s.logger.Errorf("#%016X : request validation error: %s", ctx.ID(), strings.Replace(err.Error(), "\n", " ", -1))
222+
}
223+
}
224+
225+
if err := client.Do(&ctx.Request, &ctx.Response); err != nil {
226+
s.logger.Errorf("#%016X : error while proxying request: %s", ctx.ID(), strings.Replace(err.Error(), "\n", " ", -1))
227+
switch err {
228+
case fasthttp.ErrDialTimeout:
229+
return web.RespondError(ctx, fasthttp.StatusGatewayTimeout, nil)
230+
case fasthttp.ErrNoFreeConns:
231+
return web.RespondError(ctx, fasthttp.StatusServiceUnavailable, nil)
232+
default:
233+
return web.RespondError(ctx, fasthttp.StatusBadGateway, nil)
234+
}
235+
}
236+
237+
responseValidationInput := &openapi3filter.ResponseValidationInput{
238+
RequestValidationInput: requestValidationInput,
239+
Status: ctx.Response.StatusCode(),
240+
ResponseHeader: &ctx.Response.Header,
241+
Body: ioutil.NopCloser(bytes.NewReader(ctx.Response.Body())),
242+
Options: &openapi3filter.Options{
243+
ExcludeRequestBody: false,
244+
ExcludeResponseBody: false,
245+
IncludeResponseStatus: true,
246+
MultiError: false,
247+
AuthenticationFunc: nil,
248+
},
249+
}
250+
251+
// Validate response
252+
switch s.cfg.ResponseValidation {
253+
case web.ValidationBlock:
254+
if err := openapi3filter.ValidateResponse(responseValidationInput); err != nil {
255+
s.logger.Errorf("#%016X : response validation error : %s", ctx.ID(), strings.Replace(err.Error(), "\n", " ", -1))
256+
if s.cfg.AddValidationStatusHeader {
257+
if vh := getValidationHeader(ctx, err); vh != nil {
258+
s.logger.Errorf("add header %s: %s", web.ValidationStatus, *vh)
259+
ctx.Response.Header.Add(web.ValidationStatus, *vh)
260+
return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, vh)
261+
}
262+
}
263+
return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, nil)
264+
}
265+
case web.ValidationLog:
266+
if err := openapi3filter.ValidateResponse(responseValidationInput); err != nil {
267+
s.logger.Errorf("#%016X : response validation error : %s", ctx.ID(), strings.Replace(err.Error(), "\n", " ", -1))
268+
}
269+
}
270+
271+
return nil
272+
}

0 commit comments

Comments
 (0)