diff --git a/internal/dev_server/api/delete_flag_override.go b/internal/dev_server/api/delete_flag_override.go index b438b029..6bcd8732 100644 --- a/internal/dev_server/api/delete_flag_override.go +++ b/internal/dev_server/api/delete_flag_override.go @@ -9,8 +9,7 @@ import ( ) func (s server) DeleteFlagOverride(ctx context.Context, request DeleteFlagOverrideRequestObject) (DeleteFlagOverrideResponseObject, error) { - store := model.StoreFromContext(ctx) - err := store.DeactivateOverride(ctx, request.ProjectKey, request.FlagKey) + err := model.DeleteOverride(ctx, request.ProjectKey, request.FlagKey) if err != nil { if errors.Is(err, model.ErrNotFound) { return DeleteFlagOverride404Response{}, nil diff --git a/internal/dev_server/db/sqlite.go b/internal/dev_server/db/sqlite.go index 962f174d..722dc727 100644 --- a/internal/dev_server/db/sqlite.go +++ b/internal/dev_server/db/sqlite.go @@ -324,26 +324,25 @@ func (s Sqlite) UpsertOverride(ctx context.Context, override model.Override) (mo return override, nil } -func (s Sqlite) DeactivateOverride(ctx context.Context, projectKey, flagKey string) error { - result, err := s.database.Exec(` - UPDATE overrides set active = false, version = version+1 where project_key = ? and flag_key = ? and active = true +func (s Sqlite) DeactivateOverride(ctx context.Context, projectKey, flagKey string) (int, error) { + row := s.database.QueryRowContext(ctx, ` + UPDATE overrides + set active = false, version = version+1 + where project_key = ? and flag_key = ? and active = true + returning version `, projectKey, flagKey, ) - if err != nil { - return err - } - - rowsAffected, err := result.RowsAffected() - if err != nil { - return err - } - if rowsAffected == 0 { - return model.ErrNotFound + var version int + if err := row.Scan(&version); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, errors.Wrapf(model.ErrNotFound, "no override found for flag with key, '%s', in project with key, '%s'", projectKey, flagKey) + } + return 0, err } - return nil + return version, nil } func NewSqlite(ctx context.Context, dbPath string) (Sqlite, error) { diff --git a/internal/dev_server/db/sqlite_test.go b/internal/dev_server/db/sqlite_test.go index bb77fdf8..00e35fde 100644 --- a/internal/dev_server/db/sqlite_test.go +++ b/internal/dev_server/db/sqlite_test.go @@ -302,13 +302,13 @@ func TestDBFunctions(t *testing.T) { }) t.Run("DeactivateOverride returns error when override not found", func(t *testing.T) { - err := store.DeactivateOverride(ctx, projects[0].Key, "nope") + _, err := store.DeactivateOverride(ctx, projects[0].Key, "nope") assert.ErrorIs(t, err, model.ErrNotFound) }) - t.Run("DeactivateOverride sets the override inactive", func(t *testing.T) { + t.Run("DeactivateOverride sets the override inactive and returns the current version", func(t *testing.T) { toDelete := overrides[flagKeys[0]] - err := store.DeactivateOverride(ctx, toDelete.ProjectKey, toDelete.FlagKey) + version, err := store.DeactivateOverride(ctx, toDelete.ProjectKey, toDelete.FlagKey) assert.NoError(t, err) result, err := store.GetOverridesForProject(ctx, toDelete.ProjectKey) @@ -323,6 +323,7 @@ func TestDBFunctions(t *testing.T) { found = true assert.False(t, r.Active) + assert.Equal(t, version, r.Version) } assert.True(t, found) diff --git a/internal/dev_server/model/events.go b/internal/dev_server/model/events.go index 1701322a..43ab7c78 100644 --- a/internal/dev_server/model/events.go +++ b/internal/dev_server/model/events.go @@ -1,7 +1,7 @@ package model // Event for individual flag overrides -type UpsertOverrideEvent struct { +type OverrideEvent struct { FlagKey string ProjectKey string FlagState FlagState diff --git a/internal/dev_server/model/mocks/store.go b/internal/dev_server/model/mocks/store.go index 1fd95209..37669725 100644 --- a/internal/dev_server/model/mocks/store.go +++ b/internal/dev_server/model/mocks/store.go @@ -41,11 +41,12 @@ func (m *MockStore) EXPECT() *MockStoreMockRecorder { } // DeactivateOverride mocks base method. -func (m *MockStore) DeactivateOverride(arg0 context.Context, arg1, arg2 string) error { +func (m *MockStore) DeactivateOverride(arg0 context.Context, arg1, arg2 string) (int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DeactivateOverride", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 } // DeactivateOverride indicates an expected call of DeactivateOverride. diff --git a/internal/dev_server/model/override.go b/internal/dev_server/model/override.go index f2eb4918..16eff383 100644 --- a/internal/dev_server/model/override.go +++ b/internal/dev_server/model/override.go @@ -14,14 +14,15 @@ type Override struct { Version int } -func UpsertOverride(ctx context.Context, projectKey, flagKey string, value ldvalue.Value) (Override, error) { - // TODO: validate if the flag type matches - +// getFlagStateForFlagAndProject fetches state from the store so that it can later be used to apply an override and +// construct an update. You want to call this before you write the override so that written overrides don't +// less often don't cause updates. +func getFlagStateForFlagAndProject(ctx context.Context, projectKey, flagKey string) (FlagState, error) { store := StoreFromContext(ctx) project, err := store.GetDevProject(ctx, projectKey) - if err != nil || project == nil { - return Override{}, NewError("project does not exist within dev server") + if err != nil { + return FlagState{}, err } var flagExists bool @@ -32,7 +33,15 @@ func UpsertOverride(ctx context.Context, projectKey, flagKey string, value ldval } } if !flagExists { - return Override{}, NewError("flag does not exist within dev project") + return FlagState{}, ErrNotFound + } + return project.AllFlagsState[flagKey], nil +} + +func UpsertOverride(ctx context.Context, projectKey, flagKey string, value ldvalue.Value) (Override, error) { + flagState, err := getFlagStateForFlagAndProject(ctx, projectKey, flagKey) + if err != nil { + return Override{}, err } override := Override{ @@ -43,21 +52,45 @@ func UpsertOverride(ctx context.Context, projectKey, flagKey string, value ldval Version: 1, } + store := StoreFromContext(ctx) override, err = store.UpsertOverride(ctx, override) if err != nil { return Override{}, err } - flagState := override.Apply(project.AllFlagsState[flagKey]) - GetObserversFromContext(ctx).Notify(UpsertOverrideEvent{ + GetObserversFromContext(ctx).Notify(OverrideEvent{ FlagKey: flagKey, ProjectKey: projectKey, - FlagState: flagState, + FlagState: override.Apply(flagState), }) - return override, nil } +func DeleteOverride(ctx context.Context, projectKey, flagKey string) error { + flagState, err := getFlagStateForFlagAndProject(ctx, projectKey, flagKey) + if err != nil { + return err + } + store := StoreFromContext(ctx) + version, err := store.DeactivateOverride(ctx, projectKey, flagKey) + if err != nil { + return err + } + override := Override{ + ProjectKey: projectKey, + FlagKey: flagKey, + Value: ldvalue.Null(), // since inactive, will get use the one from flagState + Active: false, + Version: version, + } + GetObserversFromContext(ctx).Notify(OverrideEvent{ + FlagKey: flagKey, + ProjectKey: projectKey, + FlagState: override.Apply(flagState), + }) + return err +} + func (o Override) Apply(state FlagState) FlagState { flagVersion := state.Version + o.Version flagValue := state.Value diff --git a/internal/dev_server/model/override_test.go b/internal/dev_server/model/override_test.go index 6b47b563..45f94294 100644 --- a/internal/dev_server/model/override_test.go +++ b/internal/dev_server/model/override_test.go @@ -13,10 +13,12 @@ import ( ) func TestUpsertOverride(t *testing.T) { + t.Parallel() ctx := context.Background() mockController := gomock.NewController(t) + defer mockController.Finish() store := mocks.NewMockStore(mockController) - projKey := "proj" + projKey := t.Name() flagKey := "flg" ldValue := ldvalue.Bool(true) override := model.Override{ @@ -45,7 +47,6 @@ func TestUpsertOverride(t *testing.T) { _, err := model.UpsertOverride(ctx, projKey, flagKey, ldValue) assert.Error(t, err) - assert.Contains(t, err.Error(), "project does not exist within dev server") }) t.Run("Returns error if flag does not exist in project", func(t *testing.T) { @@ -57,7 +58,7 @@ func TestUpsertOverride(t *testing.T) { _, err := model.UpsertOverride(ctx, projKey, flagKey, ldValue) assert.Error(t, err) - assert.Contains(t, err.Error(), "flag does not exist within dev project") + assert.ErrorIs(t, model.ErrNotFound, err) }) t.Run("store fails to upsert, returns error", func(t *testing.T) { @@ -74,7 +75,7 @@ func TestUpsertOverride(t *testing.T) { store.EXPECT().UpsertOverride(gomock.Any(), override).Return(override, nil) observer. EXPECT(). - Handle(model.UpsertOverrideEvent{ + Handle(model.OverrideEvent{ FlagKey: flagKey, ProjectKey: projKey, FlagState: model.FlagState{Value: ldvalue.Bool(true), Version: 2}, @@ -86,6 +87,63 @@ func TestUpsertOverride(t *testing.T) { }) } +func TestDeleteOverride(t *testing.T) { + t.Parallel() + ctx := context.Background() + mockController := gomock.NewController(t) + defer mockController.Finish() + store := mocks.NewMockStore(mockController) + projKey := t.Name() + flagKey := "flg" + ldValue := ldvalue.Bool(true) + + project := &model.Project{ + Key: projKey, + AllFlagsState: model.FlagsState{flagKey: model.FlagState{Value: ldvalue.Bool(false), Version: 1}}, + } + + ctx = model.ContextWithStore(ctx, store) + + observers := model.NewObservers() + observer := mocks.NewMockObserver(mockController) + + observers.RegisterObserver(observer) + ctx = model.SetObserversOnContext(ctx, observers) + + t.Run("store unable to get project, returns error", func(t *testing.T) { + store.EXPECT().GetDevProject(gomock.Any(), projKey).Return(nil, errors.New("test 2")) + + _, err := model.UpsertOverride(ctx, projKey, flagKey, ldValue) + assert.Error(t, err) + }) + + t.Run("Returns error if store errors on delete", func(t *testing.T) { + store.EXPECT().GetDevProject(gomock.Any(), projKey).Return(project, nil) + store.EXPECT().DeactivateOverride(gomock.Any(), projKey, flagKey).Return(0, errors.New("store error on deactive override")) + + err := model.DeleteOverride(ctx, projKey, flagKey) + assert.Error(t, err) + }) + + t.Run("override is applied, observers are notified", func(t *testing.T) { + store.EXPECT().GetDevProject(gomock.Any(), projKey).Return(project, nil) + store.EXPECT().DeactivateOverride(gomock.Any(), projKey, flagKey).Return(2, nil) + observer. + EXPECT(). + Handle(model.OverrideEvent{ + FlagKey: flagKey, + ProjectKey: projKey, + FlagState: model.FlagState{ + Value: ldvalue.Bool(false), + Version: 3, // override version 2 + flag version 1 + }, + }) + + err := model.DeleteOverride(ctx, projKey, flagKey) + assert.Nil(t, err) + }) +} + func TestOverrideApply(t *testing.T) { projKey := "proj" flagKey := "flg" diff --git a/internal/dev_server/model/store.go b/internal/dev_server/model/store.go index 8859b944..74326a46 100644 --- a/internal/dev_server/model/store.go +++ b/internal/dev_server/model/store.go @@ -2,10 +2,10 @@ package model import ( "context" + "errors" "net/http" "github.com/gorilla/mux" - "github.com/pkg/errors" ) type ctxKey string @@ -15,7 +15,9 @@ const ctxKeyStore = ctxKey("model.Store") //go:generate go run go.uber.org/mock/mockgen -destination mocks/store.go -package mocks . Store type Store interface { - DeactivateOverride(ctx context.Context, projectKey, flagKey string) error + // DeactivateOverride deactivates the override for the flag, returning the updated version of the override. + // ErrNotFound is returned if there isn't an override for the flag. + DeactivateOverride(ctx context.Context, projectKey, flagKey string) (int, error) GetDevProjectKeys(ctx context.Context) ([]string, error) // GetDevProject fetches the project based on the projectKey. If it doesn't exist, ErrNotFound is returned GetDevProject(ctx context.Context, projectKey string) (*Project, error) diff --git a/internal/dev_server/sdk/stream_client_flags.go b/internal/dev_server/sdk/stream_client_flags.go index 4eac5e91..633f7043 100644 --- a/internal/dev_server/sdk/stream_client_flags.go +++ b/internal/dev_server/sdk/stream_client_flags.go @@ -53,7 +53,7 @@ type clientFlagsObserver struct { func (c clientFlagsObserver) Handle(event interface{}) { log.Printf("clientFlagsObserver: handling flag state event: %v", event) switch event := event.(type) { - case model.UpsertOverrideEvent: + case model.OverrideEvent: err := SendMessage(c.updateChan, TYPE_PATCH, clientFlag{ Key: event.FlagKey, Version: event.FlagState.Version, diff --git a/internal/dev_server/sdk/stream_server_flags.go b/internal/dev_server/sdk/stream_server_flags.go index 1d3c0a4f..fa0c5592 100644 --- a/internal/dev_server/sdk/stream_server_flags.go +++ b/internal/dev_server/sdk/stream_server_flags.go @@ -54,7 +54,7 @@ type serverFlagsObserver struct { func (c serverFlagsObserver) Handle(event interface{}) { log.Printf("serverFlagsObserver: handling flag state event: %v", event) switch event := event.(type) { - case model.UpsertOverrideEvent: + case model.OverrideEvent: if event.ProjectKey != c.projectKey { return }