diff --git a/internal/eval/compile.go b/internal/eval/compile.go index 599c349c..54448e2a 100644 --- a/internal/eval/compile.go +++ b/internal/eval/compile.go @@ -66,9 +66,9 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { case ast.ScopeTypeAll: return ast.True() case ast.ScopeTypeEq: - return ast.NewNode(varNode).Equal(ast.Value(t.Entity)) + return ast.NewNode(varNode).Equal(ast.Value(entityReferenceToUID(t.Entity))) case ast.ScopeTypeIn: - return ast.NewNode(varNode).In(ast.Value(t.Entity)) + return ast.NewNode(varNode).In(ast.Value(entityReferenceToUID(t.Entity))) case ast.ScopeTypeInSet: vals := make([]types.Value, len(t.Entities)) for i, e := range t.Entities { @@ -79,8 +79,19 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { return ast.NewNode(varNode).Is(t.Type) case ast.ScopeTypeIsIn: - return ast.NewNode(varNode).IsIn(t.Type, ast.Value(t.Entity)) + return ast.NewNode(varNode).IsIn(t.Type, ast.Value(entityReferenceToUID(t.Entity))) default: panic(fmt.Sprintf("unknown scope type %T", t)) } } + +func entityReferenceToUID(ef types.EntityReference) types.EntityUID { + switch e := ef.(type) { + case types.EntityUID: + return e + case types.SlotID: + panic("variable slot cannot be evaluated, you should instantiate a template-linked policy first") + default: + panic(fmt.Sprintf("unknown entity reference type %T", e)) + } +} diff --git a/internal/eval/partial.go b/internal/eval/partial.go index 4465b030..8b3765e4 100644 --- a/internal/eval/partial.go +++ b/internal/eval/partial.go @@ -139,14 +139,14 @@ func partialScopeEval(env Env, ent types.Value, in ast.IsScopeNode) (evaled bool case ast.ScopeTypeEq: return true, e == t.Entity case ast.ScopeTypeIn: - return true, entityInOne(env, e, t.Entity) + return true, entityInOne(env, e, entityReferenceToUID(t.Entity)) case ast.ScopeTypeInSet: set := mapset.Immutable(t.Entities...) return true, entityInSet(env, e, set) case ast.ScopeTypeIs: return true, e.Type == t.Type case ast.ScopeTypeIsIn: - return true, e.Type == t.Type && entityInOne(env, e, t.Entity) + return true, e.Type == t.Type && entityInOne(env, e, entityReferenceToUID(t.Entity)) default: panic(fmt.Sprintf("unknown scope type %T", t)) } diff --git a/internal/json/json.go b/internal/json/json.go index 55440889..9a1c0e73 100644 --- a/internal/json/json.go +++ b/internal/json/json.go @@ -12,12 +12,13 @@ type policyJSON struct { Principal scopeJSON `json:"principal"` Action scopeJSON `json:"action"` Resource scopeJSON `json:"resource"` - Conditions []conditionJSON `json:"conditions,omitempty"` + Conditions []conditionJSON `json:"conditions"` // [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html#policy-set-format } // scopeInJSON uses the implicit form of EntityUID JSON serialization to match the Rust SDK type scopeInJSON struct { - Entity types.ImplicitlyMarshaledEntityUID `json:"entity"` + Entity *types.ImplicitlyMarshaledEntityUID `json:"entity,omitempty"` + Slot *string `json:"slot,omitempty"` } // scopeJSON uses the implicit form of EntityUID JSON serialization to match the Rust SDK @@ -27,6 +28,7 @@ type scopeJSON struct { Entities []types.ImplicitlyMarshaledEntityUID `json:"entities,omitempty"` EntityType string `json:"entity_type,omitempty"` In *scopeInJSON `json:"in,omitempty"` + Slot *string `json:"slot,omitempty"` } type conditionJSON struct { diff --git a/internal/json/json_marshal.go b/internal/json/json_marshal.go index bfb5faa3..ed614cf6 100644 --- a/internal/json/json_marshal.go +++ b/internal/json/json_marshal.go @@ -15,13 +15,27 @@ func (s *scopeJSON) FromNode(src ast.IsScopeNode) { return case ast.ScopeTypeEq: s.Op = "==" - e := types.ImplicitlyMarshaledEntityUID(t.Entity) - s.Entity = &e + switch ent := t.Entity.(type) { + case types.EntityUID: + e := types.ImplicitlyMarshaledEntityUID(ent) + s.Entity = &e + case types.SlotID: + varName := ent.String() + s.Slot = &varName + } + return case ast.ScopeTypeIn: s.Op = "in" - e := types.ImplicitlyMarshaledEntityUID(t.Entity) - s.Entity = &e + switch ent := t.Entity.(type) { + case types.EntityUID: + e := types.ImplicitlyMarshaledEntityUID(ent) + s.Entity = &e + case types.SlotID: + varName := ent.String() + s.Slot = &varName + } + return case ast.ScopeTypeInSet: s.Op = "in" @@ -38,9 +52,19 @@ func (s *scopeJSON) FromNode(src ast.IsScopeNode) { case ast.ScopeTypeIsIn: s.Op = "is" s.EntityType = string(t.Type) - s.In = &scopeInJSON{ - Entity: types.ImplicitlyMarshaledEntityUID(t.Entity), + in := &scopeInJSON{} + + switch et := t.Entity.(type) { + case types.EntityUID: + uid := types.ImplicitlyMarshaledEntityUID(et) + in.Entity = &uid + case types.SlotID: + varName := et.String() + in.Slot = &varName } + + s.In = in + return default: panic(fmt.Sprintf("unknown scope type %T", t)) @@ -317,6 +341,8 @@ func (p *Policy) MarshalJSON() ([]byte, error) { j.Principal.FromNode(p.Principal) j.Action.FromNode(p.Action) j.Resource.FromNode(p.Resource) + + j.Conditions = make([]conditionJSON, 0, len(p.Conditions)) // [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html#policy-set-format for _, c := range p.Conditions { var cond conditionJSON cond.Kind = "when" diff --git a/internal/json/json_test.go b/internal/json/json_test.go index 83d3f19b..89c4a7c5 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -475,6 +475,54 @@ func TestUnmarshalJSON(t *testing.T) { ast.Permit().When(ast.ExtensionCall("ip", ast.String("10.0.0.43")).IsInRange(ast.ExtensionCall("ip", ast.String("10.0.0.42/8")))), testutil.OK, }, + { + "principal template variable", + `{"effect":"permit","principal":{"op":"==", "slot": "?principal"},"action":{"op":"All"},"resource":{"op":"All"}}`, + ast.Permit().PrincipalEq(types.PrincipalSlot), + testutil.OK, + }, + { + "principal template variable with in operator", + `{"effect":"permit","principal":{"op":"in", "slot": "?principal"},"action":{"op":"All"},"resource":{"op":"All"}}`, + ast.Permit().PrincipalIn(types.PrincipalSlot), + testutil.OK, + }, + { + "principal template variable with is in operator", + `{"effect":"permit","principal":{"op":"is", "entity_type": "User", "in": {"slot": "?principal"} },"action":{"op":"All"},"resource":{"op":"All"}}`, + ast.Permit().PrincipalIsIn("User", types.PrincipalSlot), + testutil.OK, + }, + { + "resource template variable", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"==", "slot": "?resource"}}`, + ast.Permit().ResourceEq(types.ResourceSlot), + testutil.OK, + }, + { + "resource template variable with in operator", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"in", "slot": "?resource"}}`, + ast.Permit().ResourceIn(types.ResourceSlot), + testutil.OK, + }, + { + "resource template variable with is in operator", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"is", "entity_type": "Photo", "in": {"slot": "?resource"} }}`, + ast.Permit().ResourceIsIn("Photo", types.ResourceSlot), + testutil.OK, + }, + { + "fail if entity and slot present with equal operator", + `{"effect":"permit","principal":{"op":"==", "slot": "?principal", "entity": {"type": "User", "id": "12UA45"}},"action":{"op":"All"},"resource":{"op":"All"}}`, + nil, + testutil.Error, + }, + { + "fail if entity and slot present with in operator", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"in", "slot": "?resource", "entity": {"type": "User", "id": "12UA45"}}}`, + nil, + testutil.Error, + }, } for _, tt := range tests { diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index 5abf5463..4a89dd29 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -17,27 +17,190 @@ type isPrincipalResourceScopeNode interface { ast.IsResourceScopeNode } -func (s *scopeJSON) ToPrincipalResourceNode() (isPrincipalResourceScopeNode, error) { +func slotID(id *string) (types.SlotID, error) { + sid := *id + + switch sid { + case string(types.PrincipalSlot): + return types.PrincipalSlot, nil + case string(types.ResourceSlot): + return types.ResourceSlot, nil + default: + return "", fmt.Errorf("unknown slot ID: %v", sid) + } +} + +func scopeEntityReference(s *scopeJSON) (types.EntityReference, error) { + var ref types.EntityReference + + if s.Entity != nil && s.Slot != nil { + return nil, fmt.Errorf("both entity and slot are set") + } + + if s.Entity == nil && s.Slot == nil { + return nil, fmt.Errorf("entity or slot should be set") + } + + switch { + case s.Slot != nil: + id, err := slotID(s.Slot) + if err != nil { + return nil, err + } + + ref = id + case s.Entity != nil: + ref = types.EntityUID(*s.Entity) + default: + return nil, fmt.Errorf("missing entity and slot") + } + + return ref, nil +} + +func scopeInEntityReference(s *scopeInJSON) (types.EntityReference, error) { + var ref types.EntityReference + + if s.Entity != nil && s.Slot != nil { + return nil, fmt.Errorf("both entity and slot are set") + } + + if s.Entity == nil && s.Slot == nil { + return nil, fmt.Errorf("entity or slot should be set") + } + + switch { + case s.Slot != nil: + id, err := slotID(s.Slot) + if err != nil { + return nil, err + } + + ref = id + case s.Entity != nil: + ref = types.EntityUID(*s.Entity) + default: + return nil, fmt.Errorf("missing entity and slot") + } + + return ref, nil +} + +func isSlotValid(entRef types.EntityReference, slot types.SlotID) bool { + switch v := entRef.(type) { + case types.SlotID: + return v == slot + default: + return true + } +} + +func (s *scopeJSON) ToPrincipalNode(policy *Policy, allowedSlot types.SlotID) error { switch s.Op { case "All": - return ast.Scope{}.All(), nil + return nil case "==": - if s.Entity == nil { - return nil, fmt.Errorf("missing entity") + ref, err := scopeEntityReference(s) + if err != nil { + return err } - return ast.Scope{}.Eq(types.EntityUID(*s.Entity)), nil + + if !isSlotValid(ref, allowedSlot) { + return fmt.Errorf("variable used in principal slot is not %s", allowedSlot) + } + + policy.unwrap().PrincipalEq(ref) + + return nil case "in": - if s.Entity == nil { - return nil, fmt.Errorf("missing entity") + ref, err := scopeEntityReference(s) + if err != nil { + return err + } + + if !isSlotValid(ref, allowedSlot) { + return fmt.Errorf("variable used in principal slot is not %s", allowedSlot) } - return ast.Scope{}.In(types.EntityUID(*s.Entity)), nil + + policy.unwrap().PrincipalIn(ref) + + return nil case "is": if s.In == nil { - return ast.Scope{}.Is(types.EntityType(s.EntityType)), nil + policy.unwrap().PrincipalIs(types.EntityType(s.EntityType)) + + return nil + } + + ref, err := scopeInEntityReference(s.In) + if err != nil { + return err } - return ast.Scope{}.IsIn(types.EntityType(s.EntityType), types.EntityUID(s.In.Entity)), nil + + if !isSlotValid(ref, allowedSlot) { + return fmt.Errorf("variable used in principal slot is not %s", allowedSlot) + } + + policy.unwrap().PrincipalIsIn(types.EntityType(s.EntityType), ref) + + return nil } - return nil, fmt.Errorf("unknown op: %v", s.Op) + + return fmt.Errorf("unknown op: %v", s.Op) +} + +func (s *scopeJSON) ToResourceNode(policy *Policy, allowedSlot types.SlotID) error { + switch s.Op { + case "All": + return nil + case "==": + ref, err := scopeEntityReference(s) + if err != nil { + return err + } + + if !isSlotValid(ref, allowedSlot) { + return fmt.Errorf("variable used in resource slot is not %s", allowedSlot) + } + + policy.unwrap().ResourceEq(ref) + + return nil + case "in": + ref, err := scopeEntityReference(s) + if err != nil { + return err + } + + if !isSlotValid(ref, allowedSlot) { + return fmt.Errorf("variable used in resource slot is not %s", allowedSlot) + } + + policy.unwrap().ResourceIn(ref) + + return nil + case "is": + if s.In == nil { + policy.unwrap().ResourceIs(types.EntityType(s.EntityType)) + + return nil + } + + ref, err := scopeInEntityReference(s.In) + if err != nil { + return err + } + + if !isSlotValid(ref, allowedSlot) { + return fmt.Errorf("variable used in resource slot is not %s", allowedSlot) + } + + policy.unwrap().ResourceIsIn(types.EntityType(s.EntityType), ref) + + return nil + } + + return fmt.Errorf("unknown op: %v", s.Op) } func (s *scopeJSON) ToActionNode() (ast.IsActionScopeNode, error) { @@ -306,19 +469,22 @@ func (p *Policy) UnmarshalJSON(b []byte) error { for k, v := range j.Annotations { p.unwrap().Annotate(types.Ident(k), types.String(v)) } - var err error - p.Principal, err = j.Principal.ToPrincipalResourceNode() + + err := j.Principal.ToPrincipalNode(p, types.PrincipalSlot) if err != nil { return fmt.Errorf("error in principal: %w", err) } + p.Action, err = j.Action.ToActionNode() if err != nil { return fmt.Errorf("error in action: %w", err) } - p.Resource, err = j.Resource.ToPrincipalResourceNode() + + err = j.Resource.ToResourceNode(p, types.ResourceSlot) if err != nil { return fmt.Errorf("error in resource: %w", err) } + for _, c := range j.Conditions { n, err := c.Body.ToNode() if err != nil { diff --git a/internal/json/policy_set.go b/internal/json/policy_set.go index 39ec3c5b..fb7d9122 100644 --- a/internal/json/policy_set.go +++ b/internal/json/policy_set.go @@ -1,7 +1,19 @@ package json +import "github.com/cedar-policy/cedar-go/types" + type PolicySet map[string]*Policy +type TemplateSet map[string]*Policy + +type LinkedPolicy struct { + TemplateID string `json:"templateId"` + LinkID string `json:"newId"` + Values map[string]types.ImplicitlyMarshaledEntityUID `json:"values"` +} + type PolicySetJSON struct { - StaticPolicies PolicySet `json:"staticPolicies"` + StaticPolicies PolicySet `json:"staticPolicies"` + Templates TemplateSet `json:"templates"` + TemplateLinks []LinkedPolicy `json:"templateLinks,omitempty"` } diff --git a/internal/parser/cedar_marshal.go b/internal/parser/cedar_marshal.go index d77a2b45..bb3790c3 100644 --- a/internal/parser/cedar_marshal.go +++ b/internal/parser/cedar_marshal.go @@ -2,7 +2,9 @@ package parser import ( "bytes" + "errors" "fmt" + "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/internal/consts" "github.com/cedar-policy/cedar-go/internal/extensions" @@ -33,9 +35,19 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { case ast.ScopeTypeAll: return ast.True() case ast.ScopeTypeEq: - return ast.NewNode(varNode).Equal(ast.Value(t.Entity)) + rhs, err := entityReferenceToNode(t.Entity) + if err != nil { + panic(err) + } + + return ast.NewNode(varNode).Equal(rhs) case ast.ScopeTypeIn: - return ast.NewNode(varNode).In(ast.Value(t.Entity)) + rhs, err := entityReferenceToNode(t.Entity) + if err != nil { + panic(err) + } + + return ast.NewNode(varNode).In(rhs) case ast.ScopeTypeInSet: set := make([]ast.Node, len(t.Entities)) for i, e := range t.Entities { @@ -46,12 +58,28 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { return ast.NewNode(varNode).Is(t.Type) case ast.ScopeTypeIsIn: - return ast.NewNode(varNode).IsIn(t.Type, ast.Value(t.Entity)) + rhs, err := entityReferenceToNode(t.Entity) + if err != nil { + panic(err) + } + + return ast.NewNode(varNode).IsIn(t.Type, rhs) default: panic(fmt.Sprintf("unknown scope type %T", t)) } } +func entityReferenceToNode(ef types.EntityReference) (ast.Node, error) { + switch e := ef.(type) { + case types.EntityUID: + return ast.Value(e), nil + case types.SlotID: + return ast.NewNode(ast.NodeTypeVariable{Name: types.String(e.String())}), nil + default: + return ast.Node{}, errors.New("unknown entity reference type") + } +} + func (p *Policy) marshalScope(buf *bytes.Buffer) { _, principalAll := p.Principal.(ast.ScopeTypeAll) _, actionAll := p.Action.(ast.ScopeTypeAll) diff --git a/internal/parser/cedar_parse_test.go b/internal/parser/cedar_parse_test.go index f72663c8..0b3b495a 100644 --- a/internal/parser/cedar_parse_test.go +++ b/internal/parser/cedar_parse_test.go @@ -331,7 +331,7 @@ func TestParse(t *testing.T) { // N.B. Until we support the re-rendering of comments, we have to ignore the position for the purposes of // these tests (see test "ex1") - for _, pp := range policies { + for _, pp := range policies.StaticPolicies { pp.Position = ast.Position{Offset: 0, Line: 1, Column: 1} var buf bytes.Buffer @@ -341,7 +341,7 @@ func TestParse(t *testing.T) { err = p2.UnmarshalCedar(buf.Bytes()) testutil.OK(t, err) - testutil.Equals(t, p2[0], pp) + testutil.Equals(t, p2.StaticPolicies[0], pp) } }) } @@ -364,8 +364,8 @@ permit( principal, action, resource ); var out parser.PolicySlice err := out.UnmarshalCedar([]byte(in)) testutil.OK(t, err) - testutil.Equals(t, len(out), 3) - testutil.Equals(t, out[0].Position, ast.Position{Offset: 17, Line: 2, Column: 1}) - testutil.Equals(t, out[1].Position, ast.Position{Offset: 86, Line: 7, Column: 3}) - testutil.Equals(t, out[2].Position, ast.Position{Offset: 148, Line: 10, Column: 2}) + testutil.Equals(t, len(out.StaticPolicies), 3) + testutil.Equals(t, out.StaticPolicies[0].Position, ast.Position{Offset: 17, Line: 2, Column: 1}) + testutil.Equals(t, out.StaticPolicies[1].Position, ast.Position{Offset: 86, Line: 7, Column: 3}) + testutil.Equals(t, out.StaticPolicies[2].Position, ast.Position{Offset: 148, Line: 10, Column: 2}) } diff --git a/internal/parser/cedar_tokenize.go b/internal/parser/cedar_tokenize.go index c28d6ffa..38fedce1 100644 --- a/internal/parser/cedar_tokenize.go +++ b/internal/parser/cedar_tokenize.go @@ -372,7 +372,7 @@ func (s *scanner) scanComment(ch rune) rune { func (s *scanner) scanOperator(ch0, ch rune) (TokenType, rune) { switch ch0 { - case '@', '.', ',', ';', '(', ')', '{', '}', '[', ']', '+', '-', '*': + case '@', '.', ',', ';', '(', ')', '{', '}', '[', ']', '+', '-', '*', '?': case ':': if ch == ':' { ch = s.next() diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index be650d57..c8a03bbe 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -18,7 +18,9 @@ func (p *PolicySlice) UnmarshalCedar(b []byte) error { return err } - var policySet PolicySlice + var policySet []*Policy + var templateSet []*Template + parser := newParser(tokens) for !parser.peek().isEOF() { var policy Policy @@ -26,10 +28,17 @@ func (p *PolicySlice) UnmarshalCedar(b []byte) error { return err } - policySet = append(policySet, &policy) + if len(policy.unwrap().Slots()) > 0 { + t := Template(policy) + templateSet = append(templateSet, &t) + } else { + policySet = append(policySet, &policy) + } } - *p = policySet + p.StaticPolicies = policySet + p.Templates = templateSet + return nil } @@ -192,10 +201,11 @@ func (p *parser) principal(policy *ast.Policy) error { switch p.peek().Text { case "==": p.advance() - entity, err := p.entity() + entity, err := p.entityReference() if err != nil { return err } + policy.PrincipalEq(entity) return nil case "is": @@ -206,10 +216,11 @@ func (p *parser) principal(policy *ast.Policy) error { } if p.peek().Text == "in" { p.advance() - entity, err := p.entity() + entity, err := p.entityReference() if err != nil { return err } + policy.PrincipalIsIn(path, entity) return nil } @@ -218,10 +229,11 @@ func (p *parser) principal(policy *ast.Policy) error { return nil case "in": p.advance() - entity, err := p.entity() + entity, err := p.entityReference() if err != nil { return err } + policy.PrincipalIn(entity) return nil } @@ -231,10 +243,38 @@ func (p *parser) principal(policy *ast.Policy) error { func (p *parser) entity() (types.EntityUID, error) { var res types.EntityUID + + t := p.advance() + if !t.isIdent() { + return res, p.errorf("expected ident") + } + + return p.entityFirstPathPreread(types.EntityType(t.Text)) +} + +func (p *parser) entityReference() (types.EntityReference, error) { + var res types.EntityUID + + if p.peek().Type == TokenOperator && p.peek().Text == "?" { + p.advance() // consume `?` + t := p.advance() + + varName := "?" + t.Text + switch varName { + case string(types.PrincipalSlot): + return types.PrincipalSlot, nil + case string(types.ResourceSlot): + return types.ResourceSlot, nil + } + + return nil, p.errorf("unknown variable name %v", varName) + } + t := p.advance() if !t.isIdent() { return res, p.errorf("expected ident") } + return p.entityFirstPathPreread(types.EntityType(t.Text)) } @@ -348,10 +388,11 @@ func (p *parser) resource(policy *ast.Policy) error { switch p.peek().Text { case "==": p.advance() - entity, err := p.entity() + entity, err := p.entityReference() if err != nil { return err } + policy.ResourceEq(entity) return nil case "is": @@ -362,10 +403,11 @@ func (p *parser) resource(policy *ast.Policy) error { } if p.peek().Text == "in" { p.advance() - entity, err := p.entity() + entity, err := p.entityReference() if err != nil { return err } + policy.ResourceIsIn(path, entity) return nil } @@ -374,10 +416,11 @@ func (p *parser) resource(policy *ast.Policy) error { return nil case "in": p.advance() - entity, err := p.entity() + entity, err := p.entityReference() if err != nil { return err } + policy.ResourceIn(entity) return nil } diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index 1b0aeb93..c896bef7 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -471,6 +471,60 @@ when { (if true then 2 else 3 * 4) == 2 };`, when { (if true then 2 else 3) * 4 == 8 };`, ast.Permit().When(ast.IfThenElse(ast.True(), ast.Long(2), ast.Long(3)).Multiply(ast.Long(4)).Equal(ast.Long(8))), }, + { + "principal variable", + `permit ( + principal == ?principal, + action, + resource +);`, + ast.Permit().PrincipalEq(types.PrincipalSlot), + }, + { + "principal template variable with in operator", + `permit ( + principal in ?principal, + action, + resource +);`, + ast.Permit().PrincipalIn(types.PrincipalSlot), + }, + { + "principal template variable with is in operator", + `permit ( + principal is User in ?principal, + action, + resource +);`, + ast.Permit().PrincipalIsIn("User", types.PrincipalSlot), + }, + { + "resource template variable", + `permit ( + principal, + action, + resource == ?resource +);`, + ast.Permit().ResourceEq(types.ResourceSlot), + }, + { + "resource template variable with in operator", + `permit ( + principal, + action, + resource in ?resource +);`, + ast.Permit().ResourceIn(types.ResourceSlot), + }, + { + "resource template variable with is in operator", + `permit ( + principal, + action, + resource is Photo in ?resource +);`, + ast.Permit().ResourceIsIn("Photo", types.ResourceSlot), + }, } for _, tt := range parseTests { @@ -543,7 +597,7 @@ func TestParsePolicySet(t *testing.T) { expectedPolicy := ast.Permit() expectedPolicy.Position = ast.Position{Offset: 0, Line: 1, Column: 1} - testutil.Equals(t, policies[0], (*parser.Policy)(expectedPolicy)) + testutil.Equals(t, policies.StaticPolicies[0], (*parser.Policy)(expectedPolicy)) }) t.Run("two policies", func(t *testing.T) { policyStr := []byte(`permit ( @@ -561,11 +615,11 @@ func TestParsePolicySet(t *testing.T) { expectedPolicy0 := ast.Permit() expectedPolicy0.Position = ast.Position{Offset: 0, Line: 1, Column: 1} - testutil.Equals(t, policies[0], (*parser.Policy)(expectedPolicy0)) + testutil.Equals(t, policies.StaticPolicies[0], (*parser.Policy)(expectedPolicy0)) expectedPolicy1 := ast.Forbid() expectedPolicy1.Position = ast.Position{Offset: 53, Line: 6, Column: 3} - testutil.Equals(t, policies[1], (*parser.Policy)(expectedPolicy1)) + testutil.Equals(t, policies.StaticPolicies[1], (*parser.Policy)(expectedPolicy1)) }) } diff --git a/internal/parser/policy.go b/internal/parser/policy.go index b625cf13..af476bec 100644 --- a/internal/parser/policy.go +++ b/internal/parser/policy.go @@ -2,5 +2,13 @@ package parser import "github.com/cedar-policy/cedar-go/x/exp/ast" -type PolicySlice []*Policy type Policy ast.Policy + +func (p *Policy) unwrap() *ast.Policy { + return (*ast.Policy)(p) +} + +type PolicySlice struct { + StaticPolicies []*Policy + Templates []*Template +} diff --git a/internal/parser/template.go b/internal/parser/template.go new file mode 100644 index 00000000..602269d1 --- /dev/null +++ b/internal/parser/template.go @@ -0,0 +1,130 @@ +package parser + +import ( + "encoding/json" + "fmt" + "github.com/cedar-policy/cedar-go/types" + "github.com/cedar-policy/cedar-go/x/exp/ast" +) + +type Template ast.Policy + +func (p *Template) ClonePolicy() *Policy { + clone := (*ast.Policy)(p).Clone() + parserPolicy := Policy(clone) + + return &parserPolicy +} + +type LinkedPolicy struct { + TemplateID string + LinkID string + Template *Template + + slotEnv map[types.SlotID]types.EntityUID +} + +// NewLinkedPolicy creates a new instance of LinkedPolicy. +func NewLinkedPolicy(template *Template, templateID string, linkID string, slotEnv map[types.SlotID]types.EntityUID) LinkedPolicy { + return LinkedPolicy{ + Template: template, + TemplateID: templateID, + LinkID: linkID, + slotEnv: slotEnv, + } +} + +func (p LinkedPolicy) Render() (Policy, error) { + body := p.Template.ClonePolicy().unwrap() + + if len(body.Slots()) != len(p.slotEnv) { + return Policy{}, fmt.Errorf("slot env length %d does not match template slot length %d", len(p.slotEnv), len(body.Slots())) + } + + for _, slot := range body.Slots() { + switch slot { + case types.PrincipalSlot: + body.Principal = linkScope(body.Principal, p.slotEnv) + case types.ResourceSlot: + body.Resource = linkScope(body.Resource, p.slotEnv) + default: + return Policy{}, fmt.Errorf("unknown variable %s", slot) + } + } + + return Policy(*body), nil +} + +func RenderLinkedPolicy(template *Template, slotEnv map[types.SlotID]types.EntityUID) (Policy, error) { + body := template.ClonePolicy().unwrap() + + if len(body.Slots()) != len(slotEnv) { + return Policy{}, fmt.Errorf("slot env length %d does not match template slot length %d", len(slotEnv), len(body.Slots())) + } + + for _, slot := range body.Slots() { + switch slot { + case types.PrincipalSlot: + body.Principal = linkScope(body.Principal, slotEnv) + case types.ResourceSlot: + body.Resource = linkScope(body.Resource, slotEnv) + default: + return Policy{}, fmt.Errorf("unknown variable %s", slot) + } + } + + return Policy(*body), nil +} + +func linkScope[T ast.IsScopeNode](scope T, slotEnv map[types.SlotID]types.EntityUID) T { + var linkedScope any = scope + + switch t := any(scope).(type) { + case ast.ScopeTypeEq: + t.Entity = resolveSlot(t.Entity, slotEnv) + + linkedScope = t + case ast.ScopeTypeIn: + t.Entity = resolveSlot(t.Entity, slotEnv) + + linkedScope = t + case ast.ScopeTypeIsIn: + t.Entity = resolveSlot(t.Entity, slotEnv) + + linkedScope = t + default: + panic(fmt.Sprintf("unknown scope type %T", t)) + } + + return linkedScope.(T) +} + +func resolveSlot(ef types.EntityReference, slotEnv map[types.SlotID]types.EntityUID) types.EntityReference { + switch e := ef.(type) { + case types.EntityUID: + return e + case types.SlotID: + return slotEnv[e] + default: + panic(fmt.Sprintf("unknown entity reference type %T", e)) + } +} + +// MarshalJSON marshals a LinkedPolicy to JSON following cedar-cli format. +func (p LinkedPolicy) MarshalJSON() ([]byte, error) { + lp := struct { + TemplateID string `json:"template_id"` + LinkID string `json:"link_id"` + Args map[string]string `json:"args"` + }{ + TemplateID: p.TemplateID, + LinkID: p.LinkID, + } + + lp.Args = make(map[string]string, len(p.slotEnv)) + for k, v := range p.slotEnv { + lp.Args[string(k)] = v.String() + } + + return json.Marshal(lp) +} diff --git a/internal/parser/template_test.go b/internal/parser/template_test.go new file mode 100644 index 00000000..b7021a52 --- /dev/null +++ b/internal/parser/template_test.go @@ -0,0 +1,73 @@ +package parser_test + +import ( + "github.com/cedar-policy/cedar-go/internal/parser" + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" + "github.com/cedar-policy/cedar-go/x/exp/ast" + "testing" +) + +func TestLinkTemplateToPolicy(t *testing.T) { + linkTests := []struct { + Name string + TemplateString string + TemplateID string + LinkID string + Env map[types.SlotID]types.EntityUID + Want parser.Policy + }{ + { + "principal variable", + `permit ( + principal == ?principal, + action, + resource +);`, + "principal_test", + "principal_link", + map[types.SlotID]types.EntityUID{"?principal": types.NewEntityUID("User", "alice")}, + parserPolicy(ast.Permit(). + PrincipalEq(types.EntityUID{Type: "User", ID: "alice"}). + AddSlot(types.PrincipalSlot)), + }, + { + "resource variable", + `permit ( + principal, + action, + resource == ?resource +);`, + "resource_test", + "resource_link", + map[types.SlotID]types.EntityUID{"?resource": types.NewEntityUID("Album", "trip")}, + parserPolicy(ast.Permit(). + ResourceEq(types.EntityUID{Type: "Album", ID: "trip"}). + AddSlot(types.ResourceSlot)), + }, + } + + for _, tt := range linkTests { + t.Run(tt.Name, func(t *testing.T) { + t.Parallel() + + var templateBody parser.Policy + testutil.OK(t, templateBody.UnmarshalCedar([]byte(tt.TemplateString))) + template := parser.Template(templateBody) + + linkedPolicy := parser.NewLinkedPolicy(&template, tt.TemplateID, tt.LinkID, tt.Env) + + testutil.Equals(t, linkedPolicy.LinkID, tt.LinkID) + + newPolicy, err := linkedPolicy.Render() + testutil.OK(t, err) + + newPolicy.Position = ast.Position{} + testutil.Equals(t, newPolicy, tt.Want) + }) + } +} + +func parserPolicy(inAST *ast.Policy) parser.Policy { + return parser.Policy(*inAST) +} diff --git a/policy_list.go b/policy_list.go index 835d170d..cf852d06 100644 --- a/policy_list.go +++ b/policy_list.go @@ -32,8 +32,8 @@ func (p *PolicyList) UnmarshalCedar(b []byte) error { if err := res.UnmarshalCedar(b); err != nil { return fmt.Errorf("parser error: %w", err) } - policySlice := make([]*Policy, 0, len(res)) - for _, p := range res { + policySlice := make([]*Policy, 0, len(res.StaticPolicies)) + for _, p := range res.StaticPolicies { newPolicy := newPolicy((*internalast.Policy)(p)) policySlice = append(policySlice, newPolicy) } diff --git a/policy_set_test.go b/policy_set_test.go index 4138ba4d..26a89b46 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -142,7 +142,7 @@ func TestPolicySetMap(t *testing.T) { t.Parallel() ps, err := cedar.NewPolicySetFromBytes("", []byte(`permit (principal, action, resource);`)) testutil.OK(t, err) - m := ps.Map() + m := maps.Collect(ps.All()) testutil.Equals(t, len(m), 1) } @@ -159,7 +159,7 @@ func TestPolicySetJSON(t *testing.T) { var ps cedar.PolicySet err := ps.UnmarshalJSON([]byte(`{"staticPolicies":{"policy0":{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}}}`)) testutil.OK(t, err) - testutil.Equals(t, len(ps.Map()), 1) + testutil.Equals(t, len(maps.Collect(ps.All())), 1) }) t.Run("MarshalOK", func(t *testing.T) { @@ -168,7 +168,7 @@ func TestPolicySetJSON(t *testing.T) { testutil.OK(t, err) out, err := ps.MarshalJSON() testutil.OK(t, err) - testutil.Equals(t, string(out), `{"staticPolicies":{"policy0":{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}}}`) + testutil.Equals(t, string(out), `{"staticPolicies":{"policy0":{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}},"templates":{}}`) }) } diff --git a/types/entity.go b/types/entity.go index aff8b627..8b3d7506 100644 --- a/types/entity.go +++ b/types/entity.go @@ -6,6 +6,8 @@ import ( "strings" ) +const CedarVariable = EntityType("__cedar::variable") + // An Entity defines the parents and attributes for an EntityUID. type Entity struct { UID EntityUID `json:"uid"` @@ -42,3 +44,7 @@ func (e Entity) MarshalJSON() ([]byte, error) { } return json.Marshal(m) } + +type EntityReference interface { + isEntityReference() +} diff --git a/types/entity_uid.go b/types/entity_uid.go index 2ee36e2e..047cc14b 100644 --- a/types/entity_uid.go +++ b/types/entity_uid.go @@ -38,6 +38,8 @@ func (e EntityUID) Equal(bi Value) bool { return ok && e == b } +func (a EntityUID) isEntityReference() {} + // String produces a string representation of the EntityUID, e.g. `Type::"id"`. func (e EntityUID) String() string { return string(e.Type) + "::" + strconv.Quote(string(e.ID)) } diff --git a/types/template.go b/types/template.go new file mode 100644 index 00000000..e606aa28 --- /dev/null +++ b/types/template.go @@ -0,0 +1,14 @@ +package types + +type SlotID string + +const ( + PrincipalSlot SlotID = "?principal" + ResourceSlot SlotID = "?resource" +) + +func (s SlotID) String() string { + return string(s) +} + +func (s SlotID) isEntityReference() {} diff --git a/x/exp/ast/policy.go b/x/exp/ast/policy.go index 5578ce39..3074d9e5 100644 --- a/x/exp/ast/policy.go +++ b/x/exp/ast/policy.go @@ -36,6 +36,10 @@ type Position struct { Column int // column number, starting at 1 (character count per line) } +type templateContext struct { + slots []types.SlotID +} + type Policy struct { Effect Effect Annotations []AnnotationType // duplicate keys are prevented via the builders @@ -44,6 +48,8 @@ type Policy struct { Resource IsResourceScopeNode Conditions []ConditionType Position Position + + tplCtx templateContext } func newPolicy(effect Effect, annotations []AnnotationType) *Policy { @@ -73,3 +79,33 @@ func (p *Policy) Unless(node Node) *Policy { p.Conditions = append(p.Conditions, ConditionType{Condition: ConditionUnless, Body: node.v}) return p } + +func (p *Policy) addSlot(entRef types.EntityReference) *Policy { + switch v := entRef.(type) { + case types.SlotID: + p.tplCtx.slots = append(p.tplCtx.slots, v) + } + + return p +} + +func (p *Policy) Slots() []types.SlotID { + return p.tplCtx.slots +} + +func (p *Policy) Clone() Policy { + clonedPolicy := Policy{ + Effect: p.Effect, + Annotations: append([]AnnotationType(nil), p.Annotations...), + Principal: p.Principal, + Action: p.Action, + Resource: p.Resource, + Conditions: append([]ConditionType(nil), p.Conditions...), + Position: p.Position, + tplCtx: templateContext{ + slots: append([]types.SlotID(nil), p.tplCtx.slots...), + }, + } + + return clonedPolicy +} diff --git a/x/exp/ast/scope.go b/x/exp/ast/scope.go index 8d81057c..55d95bb6 100644 --- a/x/exp/ast/scope.go +++ b/x/exp/ast/scope.go @@ -10,11 +10,11 @@ func (s Scope) All() ScopeTypeAll { return ScopeTypeAll{} } -func (s Scope) Eq(entity types.EntityUID) ScopeTypeEq { +func (s Scope) Eq(entity types.EntityReference) ScopeTypeEq { return ScopeTypeEq{Entity: entity} } -func (s Scope) In(entity types.EntityUID) ScopeTypeIn { +func (s Scope) In(entity types.EntityReference) ScopeTypeIn { return ScopeTypeIn{Entity: entity} } @@ -26,28 +26,30 @@ func (s Scope) Is(entityType types.EntityType) ScopeTypeIs { return ScopeTypeIs{Type: entityType} } -func (s Scope) IsIn(entityType types.EntityType, entity types.EntityUID) ScopeTypeIsIn { +func (s Scope) IsIn(entityType types.EntityType, entity types.EntityReference) ScopeTypeIsIn { return ScopeTypeIsIn{Type: entityType, Entity: entity} } -func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { +func (p *Policy) PrincipalEq(entity types.EntityReference) *Policy { p.Principal = Scope{}.Eq(entity) - return p + return p.addSlot(entity) } -func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { +func (p *Policy) PrincipalIn(entity types.EntityReference) *Policy { p.Principal = Scope{}.In(entity) - return p + return p.addSlot(entity) } func (p *Policy) PrincipalIs(entityType types.EntityType) *Policy { p.Principal = Scope{}.Is(entityType) + return p } -func (p *Policy) PrincipalIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { +func (p *Policy) PrincipalIsIn(entityType types.EntityType, entity types.EntityReference) *Policy { p.Principal = Scope{}.IsIn(entityType, entity) - return p + + return p.addSlot(entity) } func (p *Policy) ActionEq(entity types.EntityUID) *Policy { @@ -65,14 +67,14 @@ func (p *Policy) ActionInSet(entities ...types.EntityUID) *Policy { return p } -func (p *Policy) ResourceEq(entity types.EntityUID) *Policy { +func (p *Policy) ResourceEq(entity types.EntityReference) *Policy { p.Resource = Scope{}.Eq(entity) - return p + return p.addSlot(entity) } -func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { +func (p *Policy) ResourceIn(entity types.EntityReference) *Policy { p.Resource = Scope{}.In(entity) - return p + return p.addSlot(entity) } func (p *Policy) ResourceIs(entityType types.EntityType) *Policy { @@ -80,9 +82,9 @@ func (p *Policy) ResourceIs(entityType types.EntityType) *Policy { return p } -func (p *Policy) ResourceIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { +func (p *Policy) ResourceIsIn(entityType types.EntityType, entity types.EntityReference) *Policy { p.Resource = Scope{}.IsIn(entityType, entity) - return p + return p.addSlot(entity) } type IsScopeNode interface { @@ -132,7 +134,7 @@ type ScopeTypeEq struct { PrincipalScopeNode ActionScopeNode ResourceScopeNode - Entity types.EntityUID + Entity types.EntityReference } type ScopeTypeIn struct { @@ -140,7 +142,7 @@ type ScopeTypeIn struct { PrincipalScopeNode ActionScopeNode ResourceScopeNode - Entity types.EntityUID + Entity types.EntityReference } type ScopeTypeInSet struct { @@ -161,5 +163,5 @@ type ScopeTypeIsIn struct { PrincipalScopeNode ResourceScopeNode Type types.EntityType - Entity types.EntityUID + Entity types.EntityReference } diff --git a/x/exp/templates/README.md b/x/exp/templates/README.md new file mode 100644 index 00000000..28c6e074 --- /dev/null +++ b/x/exp/templates/README.md @@ -0,0 +1,196 @@ +# Cedar Templates for Go + +Cedar Templates is a feature that extends the Cedar policy language in Go by allowing you to create policy templates with placeholder variables that can be filled in at runtime. This README explains the basics of Cedar Templates and provides examples of how to use them. + +## Overview + +Cedar policy language provides a way to define access control policies for your applications. Templates enhance this capability by allowing you to create policy patterns that can be instantiated with specific values at runtime. This is particularly useful when you need to create similar policies for different entities without duplicating policy code. + +## Key Concepts + +- **Template**: A Cedar policy with placeholders (slots) that can be filled in at runtime. +- **Slots**: Placeholders in a template denoted by a question mark followed by an identifier (e.g., `?principal`). +- **Linking**: The process of binding concrete values to slots in a template to create a usable policy. +- **PolicySet**: A collection of policies and templates that can be used for authorization decisions. + +## Basic Usage + +### Creating a Template + +A template looks like a regular Cedar policy but includes slots (marked with `?`) for values to be filled in later: + +```go +templateStr := `permit ( + principal == ?principal, + action, + resource == ?resource +) +when { resource.owner == principal };` + +var template templates.Template +err := template.UnmarshalCedar([]byte(templateStr)) +if err != nil { + // handle error +} +``` + +### Creating a PolicySet and Adding Templates + +```go +// Create a new empty PolicySet +policySet := templates.NewPolicySet() + +// Add a template to the PolicySet +templateID := cedar.PolicyID("access_template") +policySet.AddTemplate(templateID, &template) +``` + +### Linking a Template to Create a Policy + +Once you have a template, you can link it with specific entity values to create a concrete policy: + +```go +// Define the slot values +slotValues := map[types.SlotID]types.EntityUID{ + "?principal": types.NewEntityUID("User", "alice"), + "?resource": types.NewEntityUID("Document", "report"), +} + +// Link the template to create a policy +linkID := cedar.PolicyID("alice_report_access") +err = policySet.LinkTemplate(templateID, linkID, slotValues) +if err != nil { + // handle error +} +``` + +### Using Templates for Authorization + +```go +// Create a request +request := cedar.Request{ + Principal: cedar.NewEntityUID("User", "alice"), + Action: cedar.NewEntityUID("Action", "read"), + Resource: cedar.NewEntityUID("Document", "report"), + Context: types.NewRecord(nil), +} + +// Create an entity store with relevant entities +entities := types.NewEntityMap() +// Add entities to the store... + +// Make an authorization decision +decision, diagnostic := templates.Authorize(policySet, entities, request) + +// Check the decision +if decision == cedar.Allow { + // Access granted +} else { + // Access denied +} +``` + +## Advanced Examples + +### Example 1: Role-Based Access Control Template + +```go +// Template that grants access based on role +roleBasedTemplate := `permit ( + principal, + action, + resource +) +when { principal.roles.contains(?role) };` + +// Link with a specific role +roleSlots := map[types.SlotID]types.EntityUID{ + "?role": types.NewEntityUID("Role", "admin"), +} +policySet.LinkTemplate(cedar.PolicyID("role_template"), cedar.PolicyID("admin_access"), roleSlots) +``` + +### Example 2: Resource Ownership Template + +```go +// Template for resource ownership +ownershipTemplate := `permit ( + principal == ?owner, + action in [Action::"read", Action::"write", Action::"delete"], + resource == ?resource +);` + +// Link with specific owner and resource +ownershipSlots := map[types.SlotID]types.EntityUID{ + "?owner": types.NewEntityUID("User", "bob"), + "?resource": types.NewEntityUID("Photo", "vacation"), +} +policySet.LinkTemplate(cedar.PolicyID("ownership_template"), cedar.PolicyID("bob_photo_ownership"), ownershipSlots) +``` + +### Example 3: Handling Multiple Templates + +```go +// Load templates from Cedar language text +policySetStr := ` +// Resource ownership template +template ownership_tpl(principal, resource) { + permit( + principal == ?principal, + action in [Action::"read", Action::"write"], + resource == ?resource + ); +} + +// Role-based access template +template role_tpl(role) { + permit( + principal, + action, + resource + ) + when { principal.roles.contains(?role) }; +} +` + +policySet, err := templates.NewPolicySetFromBytes("policies.cedar", []byte(policySetStr)) +if err != nil { + // handle error +} + +// Link templates +policySet.LinkTemplate("ownership_tpl", "alice_doc1_ownership", map[types.SlotID]types.EntityUID{ + "?principal": types.NewEntityUID("User", "alice"), + "?resource": types.NewEntityUID("Document", "doc1"), +}) + +policySet.LinkTemplate("role_tpl", "admin_access", map[types.SlotID]types.EntityUID{ + "?role": types.NewEntityUID("Role", "admin"), +}) +``` + +## Working with Template Outputs + +After linking a template, the resulting policy can be: + +1. Used for authorization via the `templates.Authorize()` function +2. Serialized to Cedar language format with `MarshalCedar()` +3. Serialized to JSON format with `MarshalJSON()` + +## Notes and Best Practices + +1. **Template Management**: Keep track of template IDs and linked policy IDs to manage them effectively. +2. **Error Handling**: Always check for errors when parsing templates, linking them, or making authorization decisions. +3. **Entity Management**: Ensure your entity store contains all entities referenced in your policies and templates. +4. **Slot Validation**: Verify that all required slots are provided when linking a template. +5. **Experimental Status**: Note that the templates package is in the experimental (`x/exp`) namespace and may undergo changes. + +## Additional Resources + +- [Cedar Policy Documentation](https://docs.cedarpolicy.com/) +- [Cedar Templates Documentation](https://docs.cedarpolicy.com/policies/templates.html) +- [Go API Reference](https://pkg.go.dev/github.com/cedar-policy/cedar-go) + +## License + +Cedar is licensed under the Apache License, Version 2.0 \ No newline at end of file diff --git a/x/exp/templates/authorize.go b/x/exp/templates/authorize.go new file mode 100644 index 00000000..2a74099a --- /dev/null +++ b/x/exp/templates/authorize.go @@ -0,0 +1,64 @@ +package templates + +import ( + "github.com/cedar-policy/cedar-go" + "iter" + + "github.com/cedar-policy/cedar-go/internal/eval" + "github.com/cedar-policy/cedar-go/types" +) + +// PolicyIterator is an interface which abstracts an iterable set of policies. +type PolicyIterator interface { + // All returns an iterator over all the policies in the set + All() iter.Seq2[cedar.PolicyID, *Policy] +} + +// Authorize uses the combination of the PolicySet and Entities to determine +// if the given Request to determine Decision and Diagnostic. +func Authorize(policies PolicyIterator, entities types.EntityGetter, req cedar.Request) (cedar.Decision, cedar.Diagnostic) { + if entities == nil { + var zero types.EntityMap + entities = zero + } + env := eval.Env{ + Entities: entities, + Principal: req.Principal, + Action: req.Action, + Resource: req.Resource, + Context: req.Context, + } + var diag cedar.Diagnostic + var forbids []cedar.DiagnosticReason + var permits []cedar.DiagnosticReason + // Don't try to short circuit this. + // - Even though single forbid means forbid + // - All policy should be run to collect errors + // - For permit, all permits must be run to collect annotations + // - For forbid, forbids must be run to collect annotations + for id, po := range policies.All() { + result, err := po.eval.Eval(env) + if err != nil { + diag.Errors = append(diag.Errors, cedar.DiagnosticError{PolicyID: id, Position: po.Position(), Message: err.Error()}) + continue + } + if !result { + continue + } + if po.Effect() == cedar.Forbid { + forbids = append(forbids, cedar.DiagnosticReason{PolicyID: id, Position: po.Position()}) + } else { + permits = append(permits, cedar.DiagnosticReason{PolicyID: id, Position: po.Position()}) + } + } + if len(forbids) > 0 { + diag.Reasons = forbids + return cedar.Deny, diag + } + if len(permits) > 0 { + diag.Reasons = permits + return cedar.Allow, diag + } + + return cedar.Deny, diag +} diff --git a/x/exp/templates/authorize_test.go b/x/exp/templates/authorize_test.go new file mode 100644 index 00000000..66122448 --- /dev/null +++ b/x/exp/templates/authorize_test.go @@ -0,0 +1,240 @@ +package templates_test + +import ( + "testing" + + "github.com/cedar-policy/cedar-go" + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" + "github.com/cedar-policy/cedar-go/x/exp/templates" +) + +func TestIsAuthorizedFromLinkedPolicies(t *testing.T) { + t.Parallel() + cuzco := cedar.NewEntityUID("coder", "cuzco") + dropTable := cedar.NewEntityUID("table", "drop") + tests := []struct { + Name string + Policy string + LinkEnv map[types.SlotID]types.EntityUID + TemplateID cedar.PolicyID + Entities types.EntityGetter + Principal, Action, Resource cedar.EntityUID + Context cedar.Record + Want cedar.Decision + DiagErr int + ParseErr bool + LinkErr bool + }{ + { + Name: "simple-permit", + Policy: `permit(principal == ?principal,action,resource);`, + TemplateID: "template0", + LinkEnv: map[types.SlotID]types.EntityUID{"?principal": cuzco}, + Entities: cedar.EntityMap{}, + Principal: cuzco, + Action: dropTable, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, + Want: cedar.Allow, + DiagErr: 0, + }, + { + Name: "simple-forbid", + Policy: `forbid(principal == ?principal,action,resource);`, + TemplateID: "template0", + LinkEnv: map[types.SlotID]types.EntityUID{"?principal": cuzco}, + Entities: cedar.EntityMap{}, + Principal: cuzco, + Action: dropTable, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, + Want: cedar.Deny, + DiagErr: 0, + }, + { + Name: "permit-resource-equals", + Policy: `permit(principal,action,resource == ?resource);`, + TemplateID: "template0", + LinkEnv: map[types.SlotID]types.EntityUID{"?resource": cedar.NewEntityUID("table", "whatever")}, + Entities: cedar.EntityMap{}, + Principal: cuzco, + Action: dropTable, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, + Want: cedar.Allow, + DiagErr: 0, + }, + { + Name: "permit-when-in-hierarchy", + Policy: `permit(principal in ?principal,action,resource);`, + TemplateID: "template0", + LinkEnv: map[types.SlotID]types.EntityUID{"?principal": cedar.NewEntityUID("team", "osiris")}, + Entities: cedar.EntityMap{ + cuzco: cedar.Entity{ + UID: cuzco, + Parents: cedar.NewEntityUIDSet(cedar.NewEntityUID("team", "osiris")), + }, + }, + Principal: cuzco, + Action: dropTable, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, + Want: cedar.Allow, + DiagErr: 0, + }, + { + Name: "permit-when-condition", + Policy: `permit(principal == ?principal,action,resource) when { context.x == 42 };`, + TemplateID: "template0", + LinkEnv: map[types.SlotID]types.EntityUID{"?principal": cuzco}, + Entities: cedar.EntityMap{}, + Principal: cuzco, + Action: dropTable, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.NewRecord(cedar.RecordMap{"x": cedar.Long(42)}), + Want: cedar.Allow, + DiagErr: 0, + }, + { + Name: "permit-when-condition-fails", + Policy: `permit(principal == ?principal,action,resource) when { context.x == 42 };`, + TemplateID: "template0", + LinkEnv: map[types.SlotID]types.EntityUID{"?principal": cuzco}, + Entities: cedar.EntityMap{}, + Principal: cuzco, + Action: dropTable, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.NewRecord(cedar.RecordMap{"x": cedar.Long(43)}), + Want: cedar.Deny, + DiagErr: 0, + }, + { + Name: "permit-requires-entities", + Policy: `permit(principal == ?principal,action,resource) when { principal.x == 42 };`, + TemplateID: "template0", + LinkEnv: map[types.SlotID]types.EntityUID{"?principal": cuzco}, + Entities: cedar.EntityMap{ + cuzco: cedar.Entity{ + UID: cuzco, + Attributes: cedar.NewRecord(cedar.RecordMap{"x": cedar.Long(42)}), + }, + }, + Principal: cuzco, + Action: dropTable, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, + Want: cedar.Allow, + DiagErr: 0, + }, + { + Name: "multiple-slots-without-action", + Policy: `permit(principal == ?principal,action,resource == ?resource);`, + TemplateID: "template0", + LinkEnv: map[types.SlotID]types.EntityUID{ + "?principal": cuzco, + "?resource": cedar.NewEntityUID("table", "whatever"), + }, + Entities: cedar.EntityMap{}, + Principal: cuzco, + Action: dropTable, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, + Want: cedar.Allow, + DiagErr: 0, + }, + { + Name: "incorrect-env-size", + Policy: `permit(principal == ?principal,action,resource == ?resource);`, + TemplateID: "template0", + LinkEnv: map[types.SlotID]types.EntityUID{ + "?principal": cuzco, + // Missing ?resource slot + }, + Entities: cedar.EntityMap{}, + Principal: cuzco, + Action: dropTable, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, + Want: cedar.Deny, + LinkErr: true, + }, + { + Name: "missing-template-slot", + Policy: `permit(principal == ?principal,action,resource == ?resource);`, + TemplateID: "template0", + LinkEnv: map[types.SlotID]types.EntityUID{ + "?resource": cedar.NewEntityUID("table", "whatever"), + }, + Entities: cedar.EntityMap{}, + Principal: cuzco, + Action: dropTable, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, + Want: cedar.Deny, + LinkErr: true, + }, + { + Name: "error-in-policy", + Policy: `permit(principal == ?principal,action,resource) when { resource in "foo" };`, + TemplateID: "template0", + LinkEnv: map[types.SlotID]types.EntityUID{"?principal": cuzco}, + Entities: cedar.EntityMap{}, + Principal: cuzco, + Action: dropTable, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, + Want: cedar.Deny, + DiagErr: 1, + }, + { + Name: "permit-unless", + Policy: `permit(principal == ?principal,action,resource) unless { context.x > 100 };`, + TemplateID: "template0", + LinkEnv: map[types.SlotID]types.EntityUID{"?principal": cuzco}, + Entities: cedar.EntityMap{}, + Principal: cuzco, + Action: dropTable, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.NewRecord(cedar.RecordMap{"x": cedar.Long(50)}), + Want: cedar.Allow, + DiagErr: 0, + }, + { + Name: "variable-used-in-wrong-place", + Policy: `permit(principal is coder,action,resource) when { principal == ?principal };`, + TemplateID: "template0", + LinkEnv: map[types.SlotID]types.EntityUID{"?principal": cuzco}, + Entities: cedar.EntityMap{}, + Principal: cuzco, + Action: dropTable, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, + Want: cedar.Deny, + DiagErr: 0, + ParseErr: true, + LinkErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.Name, func(t *testing.T) { + t.Parallel() + ps, err := templates.NewPolicySetFromBytes("policy.cedar", []byte(tt.Policy)) + testutil.Equals(t, err != nil, tt.ParseErr) + + err = ps.LinkTemplate(tt.TemplateID, "link0", tt.LinkEnv) + testutil.Equals(t, err != nil, tt.LinkErr) + + ok, diag := templates.Authorize(ps, tt.Entities, cedar.Request{ + Principal: tt.Principal, + Action: tt.Action, + Resource: tt.Resource, + Context: tt.Context, + }) + testutil.Equals(t, len(diag.Errors), tt.DiagErr) + testutil.Equals(t, ok, tt.Want) + }) + } +} diff --git a/x/exp/templates/policy.go b/x/exp/templates/policy.go new file mode 100644 index 00000000..10f5d050 --- /dev/null +++ b/x/exp/templates/policy.go @@ -0,0 +1,184 @@ +package templates + +import ( + "bytes" + + "github.com/cedar-policy/cedar-go" + "github.com/cedar-policy/cedar-go/types" + + "github.com/cedar-policy/cedar-go/ast" + "github.com/cedar-policy/cedar-go/internal/eval" + "github.com/cedar-policy/cedar-go/internal/json" + "github.com/cedar-policy/cedar-go/internal/parser" + internalast "github.com/cedar-policy/cedar-go/x/exp/ast" +) + +// A Policy is the parsed form of a single Cedar language policy statement. +type Policy struct { + eval eval.BoolEvaler // determines if a policy matches a request. + ast *internalast.Policy +} + +func newPolicy(astIn *internalast.Policy) *Policy { + return &Policy{eval: eval.Compile(astIn), ast: astIn} +} + +// MarshalJSON encodes a single Policy statement in the JSON format specified by the [Cedar documentation]. +// +// [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html +func (p *Policy) MarshalJSON() ([]byte, error) { + jsonPolicy := (*json.Policy)(p.ast) + return jsonPolicy.MarshalJSON() +} + +// UnmarshalJSON parses and compiles a single Policy statement in the JSON format specified by the [Cedar documentation]. +// +// [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html +func (p *Policy) UnmarshalJSON(b []byte) error { + var jsonPolicy json.Policy + if err := jsonPolicy.UnmarshalJSON(b); err != nil { + return err + } + + *p = *newPolicy((*internalast.Policy)(&jsonPolicy)) + return nil +} + +// MarshalCedar encodes a single Policy statement in the human-readable format specified by the [Cedar documentation]. +// +// [Cedar documentation]: https://docs.cedarpolicy.com/policies/syntax-grammar.html +func (p *Policy) MarshalCedar() []byte { + cedarPolicy := (*parser.Policy)(p.ast) + + var buf bytes.Buffer + cedarPolicy.MarshalCedar(&buf) + + return buf.Bytes() +} + +// UnmarshalCedar parses and compiles a single Policy statement in the human-readable format specified by the [Cedar documentation]. +// +// [Cedar documentation]: https://docs.cedarpolicy.com/policies/syntax-grammar.html +func (p *Policy) UnmarshalCedar(b []byte) error { + var cedarPolicy parser.Policy + if err := cedarPolicy.UnmarshalCedar(b); err != nil { + return err + } + *p = *newPolicy((*internalast.Policy)(&cedarPolicy)) + return nil +} + +// NewPolicyFromAST lets you create a new policy statement from a programmatically created AST. +// Do not modify the *ast.Policy after passing it into NewPolicyFromAST. +func NewPolicyFromAST(astIn *ast.Policy) *Policy { + p := newPolicy((*internalast.Policy)(astIn)) + return p +} + +// Annotations retrieves the annotations associated with this policy. +func (p *Policy) Annotations() cedar.Annotations { + res := make(cedar.Annotations, len(p.ast.Annotations)) + for _, e := range p.ast.Annotations { + res[e.Key] = e.Value + } + return res +} + +// Effect retrieves the effect of this policy. +func (p *Policy) Effect() cedar.Effect { + return cedar.Effect(p.ast.Effect) +} + +// Position retrieves the position of this policy. +func (p *Policy) Position() cedar.Position { + return cedar.Position(p.ast.Position) +} + +// SetFilename sets the filename of this policy. +func (p *Policy) SetFilename(fileName string) { + p.ast.Position.Filename = fileName +} + +// AST retrieves the AST of this policy. Do not modify the AST, as the +// compiled policy will no longer be in sync with the AST. +func (p *Policy) AST() *ast.Policy { + return (*ast.Policy)(p.ast) +} + +// We use parser.Policy as the underlying type for Template because +// a templates.Policy or cedar.Policy would be compiled, however, a Template +// is not compilable until it is linked with slot values to create a concrete policy. + +// Template represents a Cedar policy template that can be linked with slot values +// to create concrete policies. It's a wrapper around the internal parser.Policy type. +type Template parser.Policy + +// newTemplate creates a new Template from the given internal AST Policy. +func newTemplate(astIn *internalast.Policy) *Template { + t := (*Template)(astIn) + return t +} + +// MarshalCedar serializes the Template into its Cedar language representation. +// Returns the serialized template as a byte slice. +func (p *Template) MarshalCedar() []byte { + cedarPolicy := (*parser.Policy)(p) + + var buf bytes.Buffer + cedarPolicy.MarshalCedar(&buf) + + return buf.Bytes() +} + +// UnmarshalCedar parses and compiles a single Template statement in the human-readable format specified by the Cedar documentation. +// Returns an error if parsing fails. +func (p *Template) UnmarshalCedar(b []byte) error { + var cedarPolicy parser.Policy + if err := cedarPolicy.UnmarshalCedar(b); err != nil { + return err + } + + *p = *newTemplate((*internalast.Policy)(&cedarPolicy)) + + return nil +} + +// MarshalJSON encodes a single Template statement in the JSON format specified by the Cedar documentation. +// Returns the JSON-encoded template as a byte slice, or an error if encoding fails. +func (p *Template) MarshalJSON() ([]byte, error) { + policyAST := (*internalast.Policy)(p) + jsonPolicy := (*json.Policy)(policyAST) + + return jsonPolicy.MarshalJSON() +} + +// UnmarshalJSON parses and compiles a single Template statement in the JSON format specified by the Cedar documentation. +// Returns an error if parsing fails. +func (p *Template) UnmarshalJSON(b []byte) error { + var jsonPolicy json.Policy + if err := jsonPolicy.UnmarshalJSON(b); err != nil { + return err + } + + *p = *newTemplate((*internalast.Policy)(&jsonPolicy)) + + return nil +} + +// SetFilename sets the filename of this template. +// This is useful for error reporting and debugging purposes. +func (p *Template) SetFilename(fileName string) { + p.Position.Filename = fileName +} + +// Slots returns the slot IDs used in this template. +func (p *Template) Slots() []types.SlotID { + policyAST := (*internalast.Policy)(p) + return policyAST.Slots() +} + +// AST retrieves the AST of this Template. Do not modify the AST. +func (p *Template) AST() *ast.Policy { + policyAST := (*internalast.Policy)(p) + return (*ast.Policy)(policyAST) +} diff --git a/x/exp/templates/policy_list.go b/x/exp/templates/policy_list.go new file mode 100644 index 00000000..94fc610f --- /dev/null +++ b/x/exp/templates/policy_list.go @@ -0,0 +1,86 @@ +package templates + +import ( + "bytes" + "fmt" + + "github.com/cedar-policy/cedar-go/ast" + "github.com/cedar-policy/cedar-go/internal/parser" +) + +// PolicyList represents a list of un-named Policy's. Cedar documents, unlike the PolicySet form, don't have a means of +// naming individual policies. +type PolicyList struct { + StaticPolicies []*Policy // StaticPolicies holds the list of static (non-template) policies. + Templates []*Template // Templates holds the list of policy templates. +} + +// NewPolicyListFromBytes creates a PolicyList from the given Cedar policy document bytes and assigns the provided file name +// to each policy and template for position tracking. Returns an error if parsing fails. +func NewPolicyListFromBytes(fileName string, document []byte) (PolicyList, error) { + var policySlice PolicyList + if err := policySlice.UnmarshalCedar(document); err != nil { + return PolicyList{}, err + } + for _, p := range policySlice.StaticPolicies { + p.SetFilename(fileName) + } + + for _, p := range policySlice.Templates { + p.SetFilename(fileName) + } + + return policySlice, nil +} + +// UnmarshalCedar parses a concatenation of un-named Cedar policy statements from the provided byte slice and populates +// the PolicyList with static policies and templates. Returns an error if parsing fails. +func (p *PolicyList) UnmarshalCedar(b []byte) error { + var res parser.PolicySlice + if err := res.UnmarshalCedar(b); err != nil { + return fmt.Errorf("parser error: %w", err) + } + + staticPolicies := make([]*Policy, 0, len(res.StaticPolicies)) + for _, p := range res.StaticPolicies { + newPolicy := NewPolicyFromAST((*ast.Policy)(p)) + staticPolicies = append(staticPolicies, newPolicy) + } + + templates := make([]*Template, 0, len(res.Templates)) + for _, p := range res.Templates { + t := Template(*p) + templates = append(templates, &t) + } + + p.StaticPolicies = staticPolicies + p.Templates = templates + + return nil +} + +// MarshalCedar emits a concatenated Cedar representation of the policies and templates in the PolicyList as a byte slice. +func (p PolicyList) MarshalCedar() []byte { + var buf bytes.Buffer + for i, policy := range p.StaticPolicies { + buf.Write(policy.MarshalCedar()) + + if i < len(p.StaticPolicies)-1 { + buf.WriteString("\n\n") + } + } + + if len(p.Templates) > 0 { + buf.WriteString("\n\n") + } + + for i, template := range p.Templates { + buf.Write(template.MarshalCedar()) + + if i < len(p.Templates)-1 { + buf.WriteString("\n\n") + } + } + + return buf.Bytes() +} diff --git a/x/exp/templates/policy_list_test.go b/x/exp/templates/policy_list_test.go new file mode 100644 index 00000000..d7f8944a --- /dev/null +++ b/x/exp/templates/policy_list_test.go @@ -0,0 +1,56 @@ +package templates_test + +import ( + "github.com/cedar-policy/cedar-go/x/exp/templates" + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" +) + +func TestPolicySlice(t *testing.T) { + t.Parallel() + + policiesStr := `permit ( + principal, + action == Action::"editPhoto", + resource +) +when { resource.owner == principal }; + +forbid ( + principal in Groups::"bannedUsers", + action, + resource +);` + + policies, err := templates.NewPolicyListFromBytes("", []byte(policiesStr)) + testutil.OK(t, err) + testutil.Equals(t, string(policies.MarshalCedar()), policiesStr) +} + +func TestPolicyWithTemplateSlice(t *testing.T) { + t.Parallel() + + policiesStr := `permit ( + principal, + action == Action::"editPhoto", + resource +) +when { resource.owner == principal }; + +forbid ( + principal in Groups::"bannedUsers", + action, + resource +); + +permit ( + principal == ?principal, + action, + resource +);` + + policies, err := templates.NewPolicyListFromBytes("", []byte(policiesStr)) + testutil.OK(t, err) + testutil.Equals(t, string(policies.MarshalCedar()), policiesStr) +} diff --git a/x/exp/templates/policy_set.go b/x/exp/templates/policy_set.go new file mode 100644 index 00000000..7cee1ef2 --- /dev/null +++ b/x/exp/templates/policy_set.go @@ -0,0 +1,327 @@ +// Package templates provides an implementation of the Cedar language authorizer. +package templates + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/cedar-policy/cedar-go" + "github.com/cedar-policy/cedar-go/internal/parser" + "github.com/cedar-policy/cedar-go/types" + internalast "github.com/cedar-policy/cedar-go/x/exp/ast" + "iter" + "maps" + "slices" + + internaljson "github.com/cedar-policy/cedar-go/internal/json" +) + +type PolicyMap map[cedar.PolicyID]*Policy + +// All returns an iterator over the policy IDs and policies in the PolicyMap. +func (p PolicyMap) All() iter.Seq2[cedar.PolicyID, *Policy] { + return maps.All(p) +} + +// PolicySet is a set of named policies against which a request can be authorized. +type PolicySet struct { + // policies are stored internally so we can handle performance, concurrency bookkeeping however we want + staticPolicies PolicyMap + linkedPolicies map[cedar.PolicyID]*LinkedPolicy + + templates map[cedar.PolicyID]*Template +} + +// NewPolicySet creates a new, empty PolicySet +func NewPolicySet() *PolicySet { + return &PolicySet{ + staticPolicies: PolicyMap{}, + templates: make(map[cedar.PolicyID]*Template), + linkedPolicies: make(map[cedar.PolicyID]*LinkedPolicy), + } +} + +// NewPolicySetFromBytes will create a PolicySet from the given text document with the given file name used in Position +// data. If there is an error parsing the document, it will be returned. +// +// NewPolicySetFromBytes assigns default PolicyIDs to the policies contained in fileName in the format "policy" where +// is incremented for each new policy found in the file. +func NewPolicySetFromBytes(fileName string, document []byte) (*PolicySet, error) { + policySlice, err := NewPolicyListFromBytes(fileName, document) + if err != nil { + return &PolicySet{}, err + } + policyMap := make(PolicyMap, len(policySlice.StaticPolicies)) + for i, p := range policySlice.StaticPolicies { + policyID := cedar.PolicyID(fmt.Sprintf("policy%d", i)) + policyMap[policyID] = p + } + + templateMap := make(map[cedar.PolicyID]*Template, len(policySlice.Templates)) + for i, p := range policySlice.Templates { + policyID := cedar.PolicyID(fmt.Sprintf("template%d", i)) + templateMap[policyID] = p + } + + return &PolicySet{staticPolicies: policyMap, templates: templateMap, linkedPolicies: make(map[cedar.PolicyID]*LinkedPolicy)}, nil +} + +// Get returns the Policy with the given ID. If a policy with the given ID +// does not exist, nil is returned. +func (p *PolicySet) Get(policyID cedar.PolicyID) *Policy { + return p.staticPolicies[policyID] +} + +// Add inserts or updates a policy with the given ID. Returns true if a policy +// with the given ID did not already exist in the set. +func (p *PolicySet) Add(policyID cedar.PolicyID, policy *Policy) bool { + _, exists := p.staticPolicies[policyID] + p.staticPolicies[policyID] = policy + return !exists +} + +// Remove removes a policy from the PolicySet. Returns true if a policy with +// the given ID already existed in the set. +func (p *PolicySet) Remove(policyID cedar.PolicyID) bool { + _, staticExists := p.staticPolicies[policyID] + delete(p.staticPolicies, policyID) + + _, linkExists := p.linkedPolicies[policyID] + delete(p.linkedPolicies, policyID) + + return staticExists || linkExists +} + +// Map returns a new PolicyMap instance of the policies in the PolicySet. +// +// Deprecated: use the iterator returned by All() like so: maps.Collect(ps.All()) +func (p *PolicySet) Map() PolicyMap { + return maps.Clone(p.staticPolicies) +} + +// MarshalCedar emits a concatenated Cedar representation of a PolicySet. The policy names are stripped, but policies +// are emitted in lexicographical order by ID. +func (p *PolicySet) MarshalCedar() []byte { + ids := make([]cedar.PolicyID, 0, len(p.staticPolicies)) + for k := range p.staticPolicies { + ids = append(ids, k) + } + slices.Sort(ids) + + var buf bytes.Buffer + i := 0 + for _, id := range ids { + policy := p.staticPolicies[id] + buf.Write(policy.MarshalCedar()) + + if i < len(p.staticPolicies)-1 { + buf.WriteString("\n\n") + } + i++ + } + return buf.Bytes() +} + +// MarshalJSON encodes a PolicySet in the JSON format specified by the [Cedar documentation]. +// +// [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html +func (p *PolicySet) MarshalJSON() ([]byte, error) { + jsonPolicySet := internaljson.PolicySetJSON{ + StaticPolicies: make(internaljson.PolicySet, len(p.staticPolicies)), + Templates: make(internaljson.TemplateSet, len(p.templates)), + TemplateLinks: make([]internaljson.LinkedPolicy, 0, len(p.linkedPolicies)), + } + for k, v := range p.staticPolicies { + jsonPolicySet.StaticPolicies[string(k)] = (*internaljson.Policy)(v.AST()) + } + for k, v := range p.templates { + jsonPolicySet.Templates[string(k)] = (*internaljson.Policy)(v.AST()) + } + for _, v := range p.linkedPolicies { + lp := internaljson.LinkedPolicy{ + TemplateID: string(v.templateID), + LinkID: string(v.linkID), + Values: make(map[string]types.ImplicitlyMarshaledEntityUID, len(v.slotEnv)), + } + + for slotID, entityUID := range v.slotEnv { + lp.Values[string(slotID)] = types.ImplicitlyMarshaledEntityUID(entityUID) + } + + jsonPolicySet.TemplateLinks = append(jsonPolicySet.TemplateLinks, lp) + } + + return json.Marshal(jsonPolicySet) +} + +// UnmarshalJSON parses and compiles a PolicySet in the JSON format specified by the [Cedar documentation]. +// +// [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html +func (p *PolicySet) UnmarshalJSON(b []byte) error { + var jsonPolicySet internaljson.PolicySetJSON + if err := json.Unmarshal(b, &jsonPolicySet); err != nil { + return err + } + *p = PolicySet{ + staticPolicies: make(PolicyMap, len(jsonPolicySet.StaticPolicies)), + templates: make(map[cedar.PolicyID]*Template, len(jsonPolicySet.Templates)), + linkedPolicies: make(map[cedar.PolicyID]*LinkedPolicy), + } + for k, v := range jsonPolicySet.StaticPolicies { + p.staticPolicies[cedar.PolicyID(k)] = newPolicy((*internalast.Policy)(v)) + } + for k, v := range jsonPolicySet.Templates { + p.templates[cedar.PolicyID(k)] = newTemplate((*internalast.Policy)(v)) + } + for _, v := range jsonPolicySet.TemplateLinks { + lp := &LinkedPolicy{ + templateID: cedar.PolicyID(v.TemplateID), + linkID: cedar.PolicyID(v.LinkID), + slotEnv: make(map[types.SlotID]types.EntityUID, len(v.Values)), + } + + for slotID, entityUID := range v.Values { + slotIDTyped := types.SlotID(slotID) + entityUIDTyped := types.EntityUID(entityUID) + + lp.slotEnv[slotIDTyped] = entityUIDTyped + } + + p.linkedPolicies[cedar.PolicyID(v.LinkID)] = lp + } + + return nil +} + +// All returns an iterator over the (PolicyID, *Policy) tuples in the PolicySet +func (p *PolicySet) All() iter.Seq2[cedar.PolicyID, *Policy] { + return func(yield func(cedar.PolicyID, *Policy) bool) { + for k, v := range p.staticPolicies { + if !yield(k, v) { + break + } + } + + for k, v := range p.linkedPolicies { + // Render links on read to make template changes propagate + policy, err := p.render(*v) + if err != nil { //todo: think how to propagate this error + continue + } + + if !yield(k, policy) { + break + } + } + } +} + +func (p *PolicySet) render(link LinkedPolicy) (*Policy, error) { + template := p.GetTemplate(link.templateID) + if template == nil { + return nil, fmt.Errorf("no such template %q", link.templateID) + } + + pTemplate := parser.Template(*template) + + policy, err := parser.RenderLinkedPolicy(&pTemplate, link.slotEnv) + if err != nil { + return nil, err + } + + astPolicy := internalast.Policy(policy) + + return newPolicy(&astPolicy), nil +} + +// LinkedPolicy represents a template that has been linked with specific slot values. +// It's a wrapper around the internal parser.LinkedPolicy type. +//type LinkedPolicy parser.LinkedPolicy + +type LinkedPolicy struct { + templateID cedar.PolicyID + linkID cedar.PolicyID + slotEnv map[types.SlotID]types.EntityUID +} + +// TemplateID returns the PolicyID of the template associated with this LinkedPolicy. +func (l *LinkedPolicy) TemplateID() cedar.PolicyID { + return l.templateID +} + +// LinkID returns the PolicyID of this LinkedPolicy. +func (l *LinkedPolicy) LinkID() cedar.PolicyID { + return l.linkID +} + +// LinkTemplate creates a LinkedPolicy by binding slot values to a template. +// Parameters: +// - template: The policy template to link +// - templateID: The identifier for the template +// - linkID: The identifier for the resulting linked policy +// - slotEnv: A map of slot IDs to entity UIDs that will be substituted into the template +// +// Returns a LinkedPolicy that can be rendered into a concrete Policy. +func (p *PolicySet) LinkTemplate(templateID cedar.PolicyID, linkID cedar.PolicyID, slotEnv map[types.SlotID]types.EntityUID) error { + _, exists := p.staticPolicies[linkID] + if exists { + return fmt.Errorf("link ID %s already exists in the policy set", linkID) + } + + template := p.GetTemplate(templateID) + if template == nil { + return fmt.Errorf("template %s not found", templateID) + } + + if len(slotEnv) < len(template.Slots()) { + return fmt.Errorf("template %s requires %d variables, slot env has %d", templateID, len(template.Slots()), len(slotEnv)) + } + + for _, slotID := range template.Slots() { + if _, ok := slotEnv[slotID]; !ok { + return fmt.Errorf("template %s requires variable %s, missing from slot env", templateID, slotID) + } + } + + link := LinkedPolicy{templateID, linkID, slotEnv} + p.linkedPolicies[linkID] = &link + + return nil +} + +// GetLinkedPolicy returns the LinkedPolicy associated with the given link ID. +// If the linked policy does not exist, it returns nil. +func (p *PolicySet) GetLinkedPolicy(linkID cedar.PolicyID) *LinkedPolicy { + return p.linkedPolicies[linkID] +} + +// GetTemplate returns the Template with the given ID. +// If a template with the given ID does not exist, nil is returned. +func (p PolicySet) GetTemplate(templateID cedar.PolicyID) *Template { + return p.templates[templateID] +} + +// AddTemplate inserts or updates a template with the given ID. +// Returns true if a template with the given ID did not already exist in the set. +func (p *PolicySet) AddTemplate(templateID cedar.PolicyID, template *Template) bool { + _, exists := p.templates[templateID] + p.templates[templateID] = template + return !exists +} + +// RemoveTemplate removes a template from the PolicySet. +// Returns true if a template with the given ID already existed in the set. +func (p *PolicySet) RemoveTemplate(templateID cedar.PolicyID) bool { + _, exists := p.templates[templateID] + if exists { + // Remove all linked policies that reference this template + for linkID, link := range p.linkedPolicies { + if link.templateID == templateID { + delete(p.linkedPolicies, linkID) + } + } + } + + delete(p.templates, templateID) + return exists +} diff --git a/x/exp/templates/policy_set_test.go b/x/exp/templates/policy_set_test.go new file mode 100644 index 00000000..c5e11e2e --- /dev/null +++ b/x/exp/templates/policy_set_test.go @@ -0,0 +1,501 @@ +package templates_test + +import ( + "fmt" + "maps" + "testing" + + "github.com/cedar-policy/cedar-go" + "github.com/cedar-policy/cedar-go/ast" + "github.com/cedar-policy/cedar-go/internal/parser" + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" + "github.com/cedar-policy/cedar-go/x/exp/templates" +) + +func TestPolicyMap(t *testing.T) { + t.Parallel() + t.Run("All", func(t *testing.T) { + t.Parallel() + pm := templates.PolicyMap{ + "foo": templates.NewPolicyFromAST(ast.Permit()), + "bar": templates.NewPolicyFromAST(ast.Permit()), + } + + got := maps.Collect(pm.All()) + testutil.Equals(t, got, pm) + }) +} + +func TestNewPolicySetFromFile(t *testing.T) { + t.Parallel() + t.Run("err-in-tokenize", func(t *testing.T) { + t.Parallel() + _, err := templates.NewPolicySetFromBytes("policy.cedar", []byte(`"`)) + testutil.Error(t, err) + }) + t.Run("err-in-parse", func(t *testing.T) { + t.Parallel() + _, err := templates.NewPolicySetFromBytes("policy.cedar", []byte(`err`)) + testutil.Error(t, err) + }) + t.Run("annotations", func(t *testing.T) { + t.Parallel() + ps, err := templates.NewPolicySetFromBytes("policy.cedar", []byte(`@key("value") permit (principal, action, resource);`)) + testutil.OK(t, err) + testutil.Equals(t, ps.Get("policy0").Annotations(), cedar.Annotations{"key": "value"}) + }) +} + +func TestUpsertPolicy(t *testing.T) { + t.Parallel() + t.Run("insert", func(t *testing.T) { + t.Parallel() + + policy0 := templates.NewPolicyFromAST(ast.Forbid()) + + var policy1 templates.Policy + testutil.OK(t, policy1.UnmarshalJSON( + []byte(`{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`), + )) + + ps := templates.NewPolicySet() + added := ps.Add("policy0", policy0) + testutil.Equals(t, added, true) + added = ps.Add("policy1", &policy1) + testutil.Equals(t, added, true) + + testutil.Equals(t, ps.Get("policy0"), policy0) + testutil.Equals(t, ps.Get("policy1"), &policy1) + testutil.Equals(t, ps.Get("policy2"), nil) + }) + t.Run("upsert", func(t *testing.T) { + t.Parallel() + + ps := templates.NewPolicySet() + + p1 := templates.NewPolicyFromAST(ast.Forbid()) + ps.Add("a wavering policy", p1) + + p2 := templates.NewPolicyFromAST(ast.Permit()) + added := ps.Add("a wavering policy", p2) + testutil.Equals(t, added, false) + + testutil.Equals(t, ps.Get("a wavering policy"), p2) + }) +} + +func TestDeletePolicy(t *testing.T) { + t.Parallel() + t.Run("delete non-existent", func(t *testing.T) { + t.Parallel() + + ps := templates.NewPolicySet() + + existed := ps.Remove("not a policy") + testutil.Equals(t, existed, false) + }) + t.Run("delete existing", func(t *testing.T) { + t.Parallel() + + ps := templates.NewPolicySet() + + p1 := templates.NewPolicyFromAST(ast.Forbid()) + ps.Add("a policy", p1) + existed := ps.Remove("a policy") + testutil.Equals(t, existed, true) + + testutil.Equals(t, ps.Get("a policy"), nil) + }) +} + +func TestNewPolicySetFromSlice(t *testing.T) { + t.Parallel() + + policiesStr := `permit ( + principal, + action == Action::"editPhoto", + resource +) +when { resource.owner == principal }; + +forbid ( + principal in Groups::"bannedUsers", + action, + resource +);` + + policies, err := templates.NewPolicyListFromBytes("", []byte(policiesStr)) + testutil.OK(t, err) + + ps := templates.NewPolicySet() + for i, p := range policies.StaticPolicies { + p.SetFilename("example.cedar") + ps.Add(cedar.PolicyID(fmt.Sprintf("policy%d", i)), p) + } + + testutil.Equals(t, ps.Get("policy0").Effect(), cedar.Permit) + testutil.Equals(t, ps.Get("policy1").Effect(), cedar.Forbid) + + testutil.Equals(t, string(ps.MarshalCedar()), policiesStr) + +} + +func TestPolicySetMap(t *testing.T) { + t.Parallel() + ps, err := templates.NewPolicySetFromBytes("", []byte(`permit (principal, action, resource);`)) + testutil.OK(t, err) + m := maps.Collect(ps.All()) + testutil.Equals(t, len(m), 1) +} + +func TestPolicySetJSON(t *testing.T) { + t.Parallel() + t.Run("UnmarshalError", func(t *testing.T) { + t.Parallel() + var ps templates.PolicySet + err := ps.UnmarshalJSON([]byte(`!@#$`)) + testutil.Error(t, err) + }) + + t.Run("UnmarshalOK", func(t *testing.T) { + t.Parallel() + var ps templates.PolicySet + err := ps.UnmarshalJSON([]byte(`{"staticPolicies":{"policy0":{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}},"templates":{"template0":{"effect":"permit","principal":{"op":"==","slot":"?principal"},"action":{"op":"All"},"resource":{"op":"All"}}},"templateLinks":[{"templateId":"template0","newId":"linked0","values":{"?principal":{"type":"User","id":"alice"}}}]}`)) + testutil.OK(t, err) + testutil.Equals(t, len(maps.Collect(ps.All())), 2) + testutil.Equals(t, ps.GetTemplate("template0") != nil, true) + testutil.Equals(t, ps.GetLinkedPolicy("linked0") != nil, true) + }) + + t.Run("MarshalOK", func(t *testing.T) { + t.Parallel() + ps, err := templates.NewPolicySetFromBytes("", []byte(`permit (principal, action, resource); + +permit (principal == ?principal, action, resource);`)) + testutil.OK(t, err) + + err = ps.LinkTemplate("template0", "linked0", map[types.SlotID]types.EntityUID{ + "?principal": types.NewEntityUID("User", "alice"), + }) + testutil.OK(t, err) + + out, err := ps.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(out), `{"staticPolicies":{"policy0":{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"},"conditions":[]}},"templates":{"template0":{"effect":"permit","principal":{"op":"==","slot":"?principal"},"action":{"op":"All"},"resource":{"op":"All"},"conditions":[]}},"templateLinks":[{"templateId":"template0","newId":"linked0","values":{"?principal":{"type":"User","id":"alice"}}}]}`) + }) +} + +func TestAll(t *testing.T) { + t.Parallel() + t.Run("all", func(t *testing.T) { + t.Parallel() + + policies := map[cedar.PolicyID]*templates.Policy{ + "policy0": templates.NewPolicyFromAST(ast.Forbid()), + "policy1": templates.NewPolicyFromAST(ast.Forbid()), + "policy2": templates.NewPolicyFromAST(ast.Forbid()), + } + + ps := templates.NewPolicySet() + for k, v := range policies { + ps.Add(k, v) + } + + got := map[cedar.PolicyID]*templates.Policy{} + for k, v := range ps.All() { + got[k] = v + } + + testutil.Equals(t, policies, got) + }) + + t.Run("break early", func(t *testing.T) { + t.Parallel() + + policies := map[cedar.PolicyID]*templates.Policy{ + "policy0": templates.NewPolicyFromAST(ast.Forbid()), + "policy1": templates.NewPolicyFromAST(ast.Forbid()), + "policy2": templates.NewPolicyFromAST(ast.Forbid()), + } + + ps := templates.NewPolicySet() + for k, v := range policies { + ps.Add(k, v) + } + + got := map[cedar.PolicyID]*templates.Policy{} + for k, v := range ps.All() { + got[k] = v + if len(got) == 2 { + break + } + } + + testutil.Equals(t, len(got), 2) + for k, v := range got { + testutil.Equals(t, policies[k], v) + } + }) +} + +func TestPolicySetTemplateManagement(t *testing.T) { + t.Run("template round-trip", func(t *testing.T) { + policySet := templates.NewPolicySet() + + var templateBody parser.Policy + templateString := `@id("test_template") +permit ( + principal == ?principal, + action, + resource +);` + testutil.OK(t, templateBody.UnmarshalCedar([]byte(templateString))) + template := templates.Template(templateBody) + + templateID := cedar.PolicyID("test_template_id") + added := policySet.AddTemplate(templateID, &template) + testutil.Equals(t, added, true) + + retrievedTemplate := policySet.GetTemplate(templateID) + testutil.Equals(t, retrievedTemplate != nil, true) + + originalBytes := template.MarshalCedar() + retrievedBytes := retrievedTemplate.MarshalCedar() + testutil.Equals(t, string(retrievedBytes), string(originalBytes)) + + removed := policySet.RemoveTemplate(templateID) + testutil.Equals(t, removed, true) + + retrievedTemplateAfterRemoval := policySet.GetTemplate(templateID) + testutil.Equals(t, retrievedTemplateAfterRemoval, (*templates.Template)(nil)) + }) + + t.Run("remove non-existent template", func(t *testing.T) { + policySet := templates.NewPolicySet() + templateID := cedar.PolicyID("non_existent_template") + removed := policySet.RemoveTemplate(templateID) + testutil.Equals(t, removed, false) + }) + + t.Run("add template with existing ID", func(t *testing.T) { + policySet := templates.NewPolicySet() + templateID := cedar.PolicyID("duplicate_template_id") + + var templateBody parser.Policy + templateString := `@id("test_template") +permit ( + principal, + action, + resource +);` + testutil.OK(t, templateBody.UnmarshalCedar([]byte(templateString))) + template := templates.Template(templateBody) + + // First add should succeed + isNew := policySet.AddTemplate(templateID, &template) + testutil.Equals(t, isNew, true) + + // Second add with same ID should return false + isNew = policySet.AddTemplate(templateID, &template) + testutil.Equals(t, isNew, false) + }) + + t.Run("cannot use link id already used by static policy", func(t *testing.T) { + templateString := `permit ( + principal == ?principal, + action, + resource +); + +permit ( + principal, + action, + resource +);` + templateID := cedar.PolicyID("template0") + policyID := cedar.PolicyID("policy0") + + policySet, err := templates.NewPolicySetFromBytes("test.cedar", []byte(templateString)) + testutil.OK(t, err) + + // Link a policy to the template + //linkID := cedar.PolicyID("linked_policy_id") + env := map[types.SlotID]types.EntityUID{ + "?principal": types.NewEntityUID("User", "alice"), + } + err = policySet.LinkTemplate(templateID, policyID, env) + testutil.Error(t, err) + }) + + t.Run("removing template removes linked policies", func(t *testing.T) { + templateString := `permit ( + principal == ?principal, + action, + resource +);` + templateID := cedar.PolicyID("template0") + + policySet, err := templates.NewPolicySetFromBytes("test.cedar", []byte(templateString)) + testutil.OK(t, err) + + // Link a policy to the template + linkID := cedar.PolicyID("linked_policy_id") + env := map[types.SlotID]types.EntityUID{ + "?principal": types.NewEntityUID("User", "alice"), + } + err = policySet.LinkTemplate(templateID, linkID, env) + testutil.OK(t, err) + + // Ensure the linked policy exists + linkedPolicy := policySet.GetLinkedPolicy(linkID) + testutil.Equals(t, linkedPolicy != nil, true) + + // Remove the template + removed := policySet.RemoveTemplate(templateID) + testutil.Equals(t, removed, true) + + // The linked policy should also be removed + linkedPolicyAfterRemoval := policySet.GetLinkedPolicy(linkID) + testutil.Equals(t, linkedPolicyAfterRemoval == nil, true) + }) + + t.Run("remove method can also remove linked policy", func(t *testing.T) { + templateString := `permit ( + principal == ?principal, + action, + resource +);` + templateID := cedar.PolicyID("template0") + + policySet, err := templates.NewPolicySetFromBytes("test.cedar", []byte(templateString)) + testutil.OK(t, err) + + // Link a policy to the template + linkID := cedar.PolicyID("linked_policy_id") + env := map[types.SlotID]types.EntityUID{ + "?principal": types.NewEntityUID("User", "alice"), + } + err = policySet.LinkTemplate(templateID, linkID, env) + testutil.OK(t, err) + + // Ensure the linked policy exists + linkedPolicy := policySet.GetLinkedPolicy(linkID) + testutil.Equals(t, linkedPolicy != nil, true) + + // Remove the template + removed := policySet.Remove(linkID) + testutil.Equals(t, removed, true) + + // The linked policy should also be removed + linkedPolicyAfterRemoval := policySet.GetLinkedPolicy(linkID) + testutil.Equals(t, linkedPolicyAfterRemoval == nil, true) + }) +} + +func TestLinkTemplateToPolicy(t *testing.T) { + linkTests := []struct { + Name string + TemplateString string + LinkID cedar.PolicyID + Env map[types.SlotID]types.EntityUID + Want string + }{ + { + "principal ScopeTypeEq", + `permit ( + principal == ?principal, + action, + resource +);`, + "scope_eq_link", + map[types.SlotID]types.EntityUID{"?principal": types.NewEntityUID("User", "bob")}, + `{"effect":"permit","principal":{"op":"==","entity":{"type":"User","id":"bob"}},"action":{"op":"All"},"resource":{"op":"All"}}`, + }, + { + "principal ScopeTypeIn", + `permit ( + principal in ?principal, + action, + resource +);`, + "scope_in_link", + map[types.SlotID]types.EntityUID{"?principal": types.NewEntityUID("User", "charlie")}, + `{"effect":"permit","principal":{"op":"in","entity":{"type":"User","id":"charlie"}},"action":{"op":"All"},"resource":{"op":"All"}}`, + }, + { + "principal ScopeTypeIsIn", + `permit ( + principal is User in ?principal, + action, + resource +);`, + "scope_isin_link", + map[types.SlotID]types.EntityUID{"?principal": types.NewEntityUID("User", "dave")}, + `{"effect":"permit","principal":{"op":"is","entity_type":"User","in":{"entity":{"type":"User","id":"dave"}}},"action":{"op":"All"},"resource":{"op":"All"}}`, + }, + { + "resource ScopeTypeEq", + `permit ( + principal, + action, + resource == ?resource +);`, + "scope_eq_link", + map[types.SlotID]types.EntityUID{"?resource": types.NewEntityUID("Album", "trip")}, + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"==","entity":{"type":"Album","id":"trip"}}}`, + }, + { + "resource ScopeTypeIn", + `permit ( + principal, + action, + resource in ?resource +);`, + "scope_in_link", + map[types.SlotID]types.EntityUID{"?resource": types.NewEntityUID("Album", "trip")}, + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"in","entity":{"type":"Album","id":"trip"}}}`, + }, + { + "resource ScopeTypeIsIn", + `permit ( + principal, + action, + resource is Album in ?resource +);`, + "scope_isin_link", + map[types.SlotID]types.EntityUID{"?resource": types.NewEntityUID("Album", "trip")}, + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"is","entity_type":"Album","in":{"entity":{"type":"Album","id":"trip"}}}}`, + }, + } + + for _, tt := range linkTests { + t.Run(tt.Name, func(t *testing.T) { + t.Parallel() + + policySet, err := templates.NewPolicySetFromBytes("test.cedar", []byte(tt.TemplateString)) + testutil.OK(t, err) + + templateID := cedar.PolicyID("template0") + + err = policySet.LinkTemplate(templateID, tt.LinkID, tt.Env) + testutil.OK(t, err) + + linkedPolicy := policySet.GetLinkedPolicy(tt.LinkID) + + testutil.Equals(t, linkedPolicy.LinkID(), tt.LinkID) + testutil.Equals(t, linkedPolicy.TemplateID(), templateID) + + for policyID, policy := range policySet.All() { + if policyID == tt.LinkID { + pj, err := policy.MarshalJSON() + testutil.OK(t, err) + + testutil.Equals(t, string(pj), tt.Want) + + break + } + } + }) + } +} diff --git a/x/exp/templates/policy_test.go b/x/exp/templates/policy_test.go new file mode 100644 index 00000000..8d80f473 --- /dev/null +++ b/x/exp/templates/policy_test.go @@ -0,0 +1,158 @@ +package templates_test + +import ( + "bytes" + "encoding/json" + "github.com/cedar-policy/cedar-go/x/exp/templates" + "testing" + + "github.com/cedar-policy/cedar-go" + "github.com/cedar-policy/cedar-go/ast" + "github.com/cedar-policy/cedar-go/internal/testutil" +) + +func prettifyJSON(in []byte) []byte { + var buf bytes.Buffer + _ = json.Indent(&buf, in, "", " ") + return buf.Bytes() +} + +func TestPolicyJSON(t *testing.T) { + t.Parallel() + + // Taken from https://docs.cedarpolicy.com/policies/json-format.html + jsonEncodedPolicy := prettifyJSON([]byte(` + { + "effect": "permit", + "principal": { + "op": "==", + "entity": { "type": "User", "id": "12UA45" } + }, + "action": { + "op": "==", + "entity": { "type": "Action", "id": "view" } + }, + "resource": { + "op": "in", + "entity": { "type": "Folder", "id": "abc" } + }, + "conditions": [ + { + "kind": "when", + "body": { + "==": { + "left": { + ".": { + "left": { + "Var": "context" + }, + "attr": "tls_version" + } + }, + "right": { + "Value": "1.3" + } + } + } + } + ] + }`, + )) + + var policy templates.Policy + testutil.OK(t, policy.UnmarshalJSON(jsonEncodedPolicy)) + + output, err := policy.MarshalJSON() + testutil.OK(t, err) + + testutil.Equals(t, string(prettifyJSON(output)), string(jsonEncodedPolicy)) +} + +func TestTemplateJSON(t *testing.T) { + t.Parallel() + + // Taken from https://docs.cedarpolicy.com/policies/json-format.html + jsonEncodedTemplate := prettifyJSON([]byte(` + { + "effect": "forbid", + "principal": { + "op": "==", + "entity": { "type": "User", "id": "12UA45" } + }, + "action": { + "op": "==", + "entity": { "type": "Action", "id": "view" } + }, + "resource": { + "op": "in", + "slot": "?resource" + }, + "conditions": [] + }`, + )) + + var policy templates.Template + testutil.OK(t, policy.UnmarshalJSON(jsonEncodedTemplate)) + + output, err := policy.MarshalJSON() + testutil.OK(t, err) + + testutil.Equals(t, string(prettifyJSON(output)), string(jsonEncodedTemplate)) +} + +func TestPolicyCedar(t *testing.T) { + t.Parallel() + + // Taken from https://docs.cedarpolicy.com/policies/syntax-policy.html + policyStr := `permit ( + principal, + action == Action::"editPhoto", + resource +) +when { resource.owner == principal };` + + var policy templates.Policy + testutil.OK(t, policy.UnmarshalCedar([]byte(policyStr))) + + testutil.Equals(t, string(policy.MarshalCedar()), policyStr) +} + +func TestTemplateCedar(t *testing.T) { + t.Parallel() + + policyStr := `permit ( + principal == ?principal, + action, + resource == ?resource +) +when { resource.owner == principal };` + + var policy templates.Template + testutil.OK(t, policy.UnmarshalCedar([]byte(policyStr))) + + testutil.Equals(t, string(policy.MarshalCedar()), policyStr) +} + +func TestPolicyAST(t *testing.T) { + t.Parallel() + + astExample := ast.Permit(). + ActionEq(cedar.NewEntityUID("Action", "editPhoto")). + When(ast.Resource().Access("owner").Equal(ast.Principal())) + + _ = templates.NewPolicyFromAST(astExample) +} + +func TestUnmarshalJSONPolicyErr(t *testing.T) { + t.Parallel() + var p templates.Policy + err := p.UnmarshalJSON([]byte("!@#$")) + testutil.Error(t, err) +} + +func TestUnmarshalCedarPolicyErr(t *testing.T) { + t.Parallel() + var p templates.Policy + err := p.UnmarshalCedar([]byte("!@#$")) + testutil.Error(t, err) +}