Skip to content

Commit 31bbd1c

Browse files
authored
Simplify role fetching logic in query engine (#282)
* Simplify role fetching logic in query engine Prior implementations of the query engine fetched role information such as the owning resource ID directly from SpiceDB, as it was the only data store available. With the introduction of CRDB, that is no longer the case and the CRDB SQL table should be considered the authoritative source of most role data. This commit updates the query engine to fetch role resource owner ID and other data from the SQL DB whenever possible, getting rid of some obscure failure modes that occur when a role has no associated actions. Signed-off-by: John Schaeffer <[email protected]> * Fix error type in RBAC v2 tests As described. Signed-off-by: John Schaeffer <[email protected]> * Wrap LockRoleForUpdate in a method to return non-DB errors As described. Signed-off-by: John Schaeffer <[email protected]> * Fix incorrect error in role update test case As described. Signed-off-by: John Schaeffer <[email protected]> --------- Signed-off-by: John Schaeffer <[email protected]>
1 parent 91d9a4e commit 31bbd1c

File tree

5 files changed

+119
-110
lines changed

5 files changed

+119
-110
lines changed

internal/query/relations.go

+83-86
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ func (e *engine) UpdateRole(ctx context.Context, actor, roleResource types.Resou
408408
return types.Role{}, err
409409
}
410410

411-
err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID)
411+
err = e.lockRoleForUpdate(dbCtx, roleResource)
412412
if err != nil {
413413
sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err)
414414

@@ -913,12 +913,18 @@ func (e *engine) ListRoles(ctx context.Context, resource types.Resource) ([]type
913913

914914
// listRoleResourceActions returns all resources and action relations for the provided resource type to the provided role.
915915
// Note: The actions returned by this function are the spicedb relationship action.
916-
func (e *engine) listRoleResourceActions(ctx context.Context, role types.Resource, resTypeName string) (map[types.Resource][]string, error) {
917-
resType := e.namespace + "/" + resTypeName
916+
func (e *engine) listRoleResourceActions(ctx context.Context, role storage.Role) ([]string, error) {
917+
roleOwnerResource, err := e.NewResourceFromID(role.ResourceID)
918+
if err != nil {
919+
return nil, err
920+
}
921+
922+
resType := e.namespace + "/" + roleOwnerResource.Type
918923
roleType := e.namespace + "/role"
919924

920925
filter := &pb.RelationshipFilter{
921-
ResourceType: resType,
926+
ResourceType: resType,
927+
OptionalResourceId: roleOwnerResource.ID.String(),
922928
OptionalSubjectFilter: &pb.SubjectFilter{
923929
SubjectType: roleType,
924930
OptionalSubjectId: role.ID.String(),
@@ -933,84 +939,47 @@ func (e *engine) listRoleResourceActions(ctx context.Context, role types.Resourc
933939
return nil, err
934940
}
935941

936-
resourceIDActions := make(map[gidx.PrefixedID][]string)
942+
out := make([]string, 0, len(relationships))
937943

938944
for _, rel := range relationships {
939-
resourceID, err := gidx.Parse(rel.Resource.ObjectId)
940-
if err != nil {
941-
return nil, err
942-
}
943-
944-
resourceIDActions[resourceID] = append(resourceIDActions[resourceID], rel.Relation)
945-
}
946-
947-
resourceActions := make(map[types.Resource][]string, len(resourceIDActions))
948-
949-
for resID, actions := range resourceIDActions {
950-
res, err := e.NewResourceFromID(resID)
951-
if err != nil {
952-
return nil, err
953-
}
945+
action := relationToAction(rel.Relation)
954946

955-
resourceActions[res] = actions
947+
out = append(out, action)
956948
}
957949

958-
return resourceActions, nil
950+
return out, nil
959951
}
960952

961-
// GetRole gets the role with it's actions.
953+
// GetRole gets the given role and its actions.
962954
func (e *engine) GetRole(ctx context.Context, roleResource types.Resource) (types.Role, error) {
963-
var (
964-
resActions map[types.Resource][]string
965-
err error
966-
)
967-
968-
for _, resType := range e.schemaRoleables {
969-
resActions, err = e.listRoleResourceActions(ctx, roleResource, resType.Name)
970-
if err != nil {
971-
return types.Role{}, err
972-
}
973-
974-
// roles are only ever created for a single resource, so we can break after the first one is found.
975-
if len(resActions) != 0 {
976-
break
977-
}
955+
dbRole, err := e.getStorageRole(ctx, roleResource)
956+
if err != nil {
957+
return types.Role{}, err
978958
}
979959

980-
if len(resActions) > 1 {
981-
return types.Role{}, ErrRoleHasTooManyResources
960+
actions, err := e.listRoleResourceActions(ctx, dbRole)
961+
if err != nil {
962+
return types.Role{}, err
982963
}
983964

984-
// returns the first resources actions.
985-
for _, actions := range resActions {
986-
for i, action := range actions {
987-
actions[i] = relationToAction(action)
988-
}
989-
990-
dbRole, err := e.store.GetRoleByID(ctx, roleResource.ID)
991-
if err != nil && !errors.Is(err, storage.ErrNoRoleFound) {
992-
e.logger.Error("error while getting role", zap.Error(err))
993-
}
965+
out := types.Role{
966+
ID: roleResource.ID,
967+
Name: dbRole.Name,
968+
Actions: actions,
994969

995-
return types.Role{
996-
ID: roleResource.ID,
997-
Name: dbRole.Name,
998-
Actions: actions,
999-
1000-
ResourceID: dbRole.ResourceID,
1001-
CreatedBy: dbRole.CreatedBy,
1002-
UpdatedBy: dbRole.UpdatedBy,
1003-
CreatedAt: dbRole.CreatedAt,
1004-
UpdatedAt: dbRole.UpdatedAt,
1005-
}, nil
970+
ResourceID: dbRole.ResourceID,
971+
CreatedBy: dbRole.CreatedBy,
972+
UpdatedBy: dbRole.UpdatedBy,
973+
CreatedAt: dbRole.CreatedAt,
974+
UpdatedAt: dbRole.UpdatedAt,
1006975
}
1007976

1008-
return types.Role{}, ErrRoleNotFound
977+
return out, nil
1009978
}
1010979

1011980
// GetRoleResource gets the role's assigned resource.
1012981
func (e *engine) GetRoleResource(ctx context.Context, roleResource types.Resource) (types.Resource, error) {
1013-
dbRole, err := e.store.GetRoleByID(ctx, roleResource.ID)
982+
dbRole, err := e.getStorageRole(ctx, roleResource)
1014983
if err != nil {
1015984
return types.Resource{}, err
1016985
}
@@ -1029,7 +998,17 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er
1029998
return err
1030999
}
10311000

1032-
err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID)
1001+
dbRole, err := e.getStorageRole(ctx, roleResource)
1002+
if err != nil {
1003+
return err
1004+
}
1005+
1006+
roleOwnerResource, err := e.NewResourceFromID(dbRole.ResourceID)
1007+
if err != nil {
1008+
return err
1009+
}
1010+
1011+
err = e.lockRoleForUpdate(dbCtx, roleResource)
10331012
if err != nil {
10341013
sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err)
10351014

@@ -1041,20 +1020,11 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er
10411020
return err
10421021
}
10431022

1044-
var resActions map[types.Resource][]string
1045-
1046-
for _, resType := range e.schemaRoleables {
1047-
resActions, err = e.listRoleResourceActions(ctx, roleResource, resType.Name)
1048-
if err != nil {
1049-
logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))
1050-
1051-
return err
1052-
}
1023+
actions, err := e.listRoleResourceActions(ctx, dbRole)
1024+
if err != nil {
1025+
logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))
10531026

1054-
// roles are only ever created for a single resource, so we can break after the first one is found.
1055-
if len(resActions) != 0 {
1056-
break
1057-
}
1027+
return err
10581028
}
10591029

10601030
roleType := e.namespace + "/role"
@@ -1069,15 +1039,16 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er
10691039
},
10701040
}
10711041

1072-
for resource, relActions := range resActions {
1073-
for _, relAction := range relActions {
1074-
filters = append(filters, &pb.RelationshipFilter{
1075-
ResourceType: e.namespace + "/" + resource.Type,
1076-
OptionalResourceId: resource.ID.String(),
1077-
OptionalRelation: relAction,
1078-
OptionalSubjectFilter: roleSubjectFilter,
1079-
})
1080-
}
1042+
ownerType := e.namespace + "/" + roleOwnerResource.Type
1043+
ownerIDStr := roleOwnerResource.ID.String()
1044+
1045+
for _, relAction := range actions {
1046+
filters = append(filters, &pb.RelationshipFilter{
1047+
ResourceType: ownerType,
1048+
OptionalResourceId: ownerIDStr,
1049+
OptionalRelation: relAction,
1050+
OptionalSubjectFilter: roleSubjectFilter,
1051+
})
10811052
}
10821053

10831054
_, err = e.store.DeleteRole(dbCtx, roleResource.ID)
@@ -1229,3 +1200,29 @@ func (e *engine) applyUpdates(ctx context.Context, updates []*pb.RelationshipUpd
12291200

12301201
return nil
12311202
}
1203+
1204+
func (e *engine) getStorageRole(ctx context.Context, roleResource types.Resource) (storage.Role, error) {
1205+
dbRole, err := e.store.GetRoleByID(ctx, roleResource.ID)
1206+
1207+
switch {
1208+
case err == nil:
1209+
return dbRole, nil
1210+
case errors.Is(err, storage.ErrNoRoleFound):
1211+
return storage.Role{}, ErrRoleNotFound
1212+
default:
1213+
return storage.Role{}, err
1214+
}
1215+
}
1216+
1217+
func (e *engine) lockRoleForUpdate(ctx context.Context, roleResource types.Resource) error {
1218+
err := e.store.LockRoleForUpdate(ctx, roleResource.ID)
1219+
1220+
switch {
1221+
case err == nil:
1222+
return nil
1223+
case errors.Is(err, storage.ErrNoRoleFound):
1224+
return ErrRoleNotFound
1225+
default:
1226+
return err
1227+
}
1228+
}

internal/query/relations_test.go

+27-14
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313

1414
"go.infratographer.com/permissions-api/internal/iapl"
1515
"go.infratographer.com/permissions-api/internal/spicedbx"
16-
"go.infratographer.com/permissions-api/internal/storage"
1716
"go.infratographer.com/permissions-api/internal/storage/teststore"
1817
"go.infratographer.com/permissions-api/internal/testingx"
1918
"go.infratographer.com/permissions-api/internal/types"
@@ -96,54 +95,68 @@ func TestCreateRoles(t *testing.T) {
9695
ctx := context.Background()
9796
e := testEngine(ctx, t, namespace, testPolicy())
9897

99-
testCases := []testingx.TestCase[[]string, []types.Role]{
98+
testCases := []testingx.TestCase[[]string, types.Role]{
10099
{
101100
Name: "CreateInvalidAction",
102101
Input: []string{
103102
"bad_action",
104103
},
105-
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[[]types.Role]) {
104+
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) {
106105
assert.Error(t, res.Err)
107106
},
108107
},
108+
{
109+
Name: "CreateNoActions",
110+
Input: []string{},
111+
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) {
112+
expActions := []string{}
113+
114+
require.NoError(t, res.Err)
115+
116+
role := res.Success
117+
assert.Equal(t, expActions, role.Actions)
118+
},
119+
},
109120
{
110121
Name: "CreateSuccess",
111122
Input: []string{
112123
"loadbalancer_get",
113124
},
114-
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[[]types.Role]) {
125+
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) {
115126
expActions := []string{
116127
"loadbalancer_get",
117128
}
118129

119-
assert.NoError(t, res.Err)
120-
require.Equal(t, 1, len(res.Success))
130+
require.NoError(t, res.Err)
121131

122-
role := res.Success[0]
132+
role := res.Success
123133
assert.Equal(t, expActions, role.Actions)
124134
},
125135
},
126136
}
127137

128-
testFn := func(ctx context.Context, actions []string) testingx.TestResult[[]types.Role] {
138+
testFn := func(ctx context.Context, actions []string) testingx.TestResult[types.Role] {
129139
tenID, err := gidx.NewID("tnntten")
130140
require.NoError(t, err)
131141
tenRes, err := e.NewResourceFromID(tenID)
132142
require.NoError(t, err)
133143
actorRes, err := e.NewResourceFromID(gidx.MustNewID("idntusr"))
134144
require.NoError(t, err)
135145

136-
_, err = e.CreateRole(ctx, actorRes, tenRes, "test", actions)
146+
role, err := e.CreateRole(ctx, actorRes, tenRes, "test", actions)
137147
if err != nil {
138-
return testingx.TestResult[[]types.Role]{
148+
return testingx.TestResult[types.Role]{
139149
Err: err,
140150
}
141151
}
142152

143-
roles, err := e.ListRoles(ctx, tenRes)
153+
roleResource, err := e.NewResourceFromID(role.ID)
154+
require.NoError(t, err)
144155

145-
return testingx.TestResult[[]types.Role]{
146-
Success: roles,
156+
obs, err := e.GetRole(ctx, roleResource)
157+
158+
return testingx.TestResult[types.Role]{
159+
Success: obs,
147160
Err: err,
148161
}
149162
}
@@ -232,7 +245,7 @@ func TestRoleUpdate(t *testing.T) {
232245
Input: gidx.MustNewID(RolePrefix),
233246
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) {
234247
require.Error(t, res.Err)
235-
assert.ErrorIs(t, res.Err, storage.ErrNoRoleFound)
248+
assert.ErrorIs(t, res.Err, ErrRoleNotFound)
236249
},
237250
},
238251
{

internal/query/roles_v2.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ func (e *engine) GetRoleV2(ctx context.Context, role types.Resource) (types.Role
202202
}
203203

204204
// 2. Get role info (name, created_by, etc.) from permissions API DB
205-
dbrole, err := e.store.GetRoleByID(ctx, role.ID)
205+
dbrole, err := e.getStorageRole(ctx, role)
206206
if err != nil {
207207
span.RecordError(err)
208208
span.SetStatus(codes.Error, err.Error())
@@ -234,7 +234,7 @@ func (e *engine) UpdateRoleV2(ctx context.Context, actor, roleResource types.Res
234234
return types.Role{}, err
235235
}
236236

237-
err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID)
237+
err = e.lockRoleForUpdate(dbCtx, roleResource)
238238
if err != nil {
239239
sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err)
240240

@@ -360,7 +360,7 @@ func (e *engine) DeleteRoleV2(ctx context.Context, roleResource types.Resource)
360360
return err
361361
}
362362

363-
err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID)
363+
err = e.lockRoleForUpdate(dbCtx, roleResource)
364364
if err != nil {
365365
sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err)
366366

@@ -399,7 +399,7 @@ func (e *engine) DeleteRoleV2(ctx context.Context, roleResource types.Resource)
399399
return err
400400
}
401401

402-
dbRole, err := e.store.GetRoleByID(dbCtx, roleResource.ID)
402+
dbRole, err := e.getStorageRole(dbCtx, roleResource)
403403
if err != nil {
404404
span.RecordError(err)
405405
span.SetStatus(codes.Error, err.Error())

0 commit comments

Comments
 (0)