Skip to content

Commit 0b9c5af

Browse files
committed
feat: restrict access to all tables by default
for #10
1 parent 1511b88 commit 0b9c5af

5 files changed

+116
-14
lines changed

fixture_test.go

+5
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import (
2323
"k8s.io/klog/v2/ktesting"
2424
)
2525

26+
var enabledTestTables = []string{"test", "test_view"}
27+
2628
type TestContext struct {
2729
server *httptest.Server
2830
db *sqlx.DB
@@ -135,6 +137,7 @@ func createTestContextUsingInMemoryDB(t testing.TB) *TestContext {
135137
Execer: db,
136138
}
137139
serverOpts.AuthOptions.disableAuth = true
140+
serverOpts.SecurityOptions.EnabledTableOrViews = enabledTestTables
138141
server, err := NewServer(serverOpts)
139142
if err != nil {
140143
t.Fatal(err)
@@ -190,6 +193,7 @@ func createTestContextWithHMACTokenAuth(t testing.TB) *TestContext {
190193
Execer: db,
191194
}
192195
serverOpts.AuthOptions.TokenFilePath = testTokenFile
196+
serverOpts.SecurityOptions.EnabledTableOrViews = enabledTestTables
193197
server, err := NewServer(serverOpts)
194198
if err != nil {
195199
t.Fatal(err)
@@ -265,6 +269,7 @@ func createTestContextWithRSATokenAuth(t testing.TB) *TestContext {
265269
Execer: db,
266270
}
267271
serverOpts.AuthOptions.RSAPublicKeyFilePath = testTokenFile
272+
serverOpts.SecurityOptions.EnabledTableOrViews = enabledTestTables
268273
server, err := NewServer(serverOpts)
269274
if err != nil {
270275
t.Fatal(err)

integration_security_test.go

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package main
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestSecurityNegativeCases(t *testing.T) {
10+
t.Run("Unauthorized", func(t *testing.T) {
11+
tc := createTestContextWithHMACTokenAuth(t)
12+
defer tc.CleanUp(t)
13+
14+
tc.authToken = "" // disable auth
15+
client := tc.Client()
16+
_, _, err := client.From("test").Select("id", "", false).Execute()
17+
assert.Error(t, err)
18+
assert.Contains(t, err.Error(), "Unauthorized")
19+
})
20+
21+
t.Run("TableAccessRestricted", func(t *testing.T) {
22+
tc := createTestContextWithHMACTokenAuth(t)
23+
defer tc.CleanUp(t)
24+
25+
client := tc.Client()
26+
_, _, err := client.From(tableNameMigrations).Select("id", "", false).Execute()
27+
assert.Error(t, err)
28+
assert.Contains(t, err.Error(), "Access Restricted")
29+
})
30+
}

server.go

+23-14
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,27 @@ const (
2121
)
2222

2323
type ServerOptions struct {
24-
Logger logr.Logger
25-
Addr string
26-
AuthOptions ServerAuthOptions
27-
Queryer sqlx.QueryerContext
28-
Execer sqlx.ExecerContext
24+
Logger logr.Logger
25+
Addr string
26+
AuthOptions ServerAuthOptions
27+
SecurityOptions ServerSecurityOptions
28+
Queryer sqlx.QueryerContext
29+
Execer sqlx.ExecerContext
2930
}
3031

3132
func (opts *ServerOptions) bindCLIFlags(fs *pflag.FlagSet) {
3233
fs.StringVar(&opts.Addr, "http-addr", ":8080", "server listen addr")
3334
opts.AuthOptions.bindCLIFlags(fs)
35+
opts.SecurityOptions.bindCLIFlags(fs)
3436
}
3537

3638
func (opts *ServerOptions) defaults() error {
3739
if err := opts.AuthOptions.defaults(); err != nil {
3840
return err
3941
}
42+
if err := opts.SecurityOptions.defaults(); err != nil {
43+
return err
44+
}
4045

4146
if opts.Logger.GetSink() == nil {
4247
opts.Logger = logr.Discard()
@@ -82,17 +87,21 @@ func NewServer(opts *ServerOptions) (*dbServer, error) {
8287

8388
// TODO: allow specifying cors config from cli / table
8489
serverMux.Use(cors.AllowAll().Handler)
85-
authMiddleware := opts.AuthOptions.createAuthMiddleware(rv.responseError)
8690

8791
{
88-
serverMux.With(authMiddleware).Group(func(r chi.Router) {
89-
routePattern := fmt.Sprintf("/{%s:[^/]+}", routeVarTableOrView)
90-
r.Get(routePattern, rv.handleQueryTableOrView)
91-
r.Post(routePattern, rv.handleInsertTable)
92-
r.Patch(routePattern, rv.handleUpdateTable)
93-
r.Put(routePattern, rv.handleUpdateSingleEntity)
94-
r.Delete(routePattern, rv.handleDeleteTable)
95-
})
92+
serverMux.
93+
With(
94+
opts.AuthOptions.createAuthMiddleware(rv.responseError),
95+
opts.SecurityOptions.createTableOrViewAccessCheckMiddleware(rv.responseError, routeVarTableOrView),
96+
).
97+
Group(func(r chi.Router) {
98+
routePattern := fmt.Sprintf("/{%s:[^/]+}", routeVarTableOrView)
99+
r.Get(routePattern, rv.handleQueryTableOrView)
100+
r.Post(routePattern, rv.handleInsertTable)
101+
r.Patch(routePattern, rv.handleUpdateTable)
102+
r.Put(routePattern, rv.handleUpdateSingleEntity)
103+
r.Delete(routePattern, rv.handleDeleteTable)
104+
})
96105
}
97106

98107
rv.server.Handler = serverMux

server_errors.go

+5
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ var (
4141
Message: "Unauthorized",
4242
StatusCode: http.StatusUnauthorized,
4343
}
44+
45+
ErrAccessRestricted = &ServerError{
46+
Message: "Access Restricted",
47+
StatusCode: http.StatusForbidden,
48+
}
4449
)
4550

4651
func ErrUnsupportedOperator(op string) *ServerError {

server_security.go

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
7+
"github.com/go-chi/chi/v5"
8+
"github.com/spf13/pflag"
9+
)
10+
11+
// TODO: generally speaking, we need a fine-grained RBAC system.
12+
13+
type ServerSecurityOptions struct {
14+
// EnabledTableOrViews list of table or view names that are accessible (read & write).
15+
EnabledTableOrViews []string
16+
}
17+
18+
func (opts *ServerSecurityOptions) bindCLIFlags(fs *pflag.FlagSet) {
19+
fs.StringSliceVar(
20+
&opts.EnabledTableOrViews,
21+
"--security-allow-table",
22+
[]string{},
23+
"list of table or view names that are accessible (read & write)",
24+
)
25+
}
26+
27+
func (opts *ServerSecurityOptions) defaults() error {
28+
return nil
29+
}
30+
31+
func (opts *ServerSecurityOptions) createTableOrViewAccessCheckMiddleware(
32+
responseErr func(w http.ResponseWriter, err error),
33+
routeVarTableOrView string,
34+
) func(http.Handler) http.Handler {
35+
accesibleTableOrViews := make(map[string]struct{})
36+
for _, t := range opts.EnabledTableOrViews {
37+
accesibleTableOrViews[t] = struct{}{}
38+
}
39+
fmt.Println(accesibleTableOrViews)
40+
41+
return func(next http.Handler) http.Handler {
42+
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
43+
target := chi.URLParam(req, routeVarTableOrView)
44+
45+
if _, ok := accesibleTableOrViews[target]; !ok {
46+
responseErr(w, ErrAccessRestricted)
47+
return
48+
}
49+
50+
next.ServeHTTP(w, req)
51+
})
52+
}
53+
}

0 commit comments

Comments
 (0)