diff --git a/cedar-language-server/src/policy/completion/provider/scope.rs b/cedar-language-server/src/policy/completion/provider/scope.rs index 3b576b54f4..a3c7e7433a 100644 --- a/cedar-language-server/src/policy/completion/provider/scope.rs +++ b/cedar-language-server/src/policy/completion/provider/scope.rs @@ -93,7 +93,7 @@ fn handle_principal_scope(policy: &Template, info: &ScopeVariableInfo) -> Comple PrincipalOrResourceConstraint::Eq(entity_reference) => create_binary_op_context( Op::eq(), principal_var, - entity_reference.into_expr(SlotId::principal()).into(), + entity_reference.into_expr(SlotId::resource()).into(), ), PrincipalOrResourceConstraint::Is(..) => create_is_context(principal_var), diff --git a/cedar-language-server/src/schema/fold.rs b/cedar-language-server/src/schema/fold.rs index 403963b400..fbdf73f5cc 100644 --- a/cedar-language-server/src/schema/fold.rs +++ b/cedar-language-server/src/schema/fold.rs @@ -79,15 +79,15 @@ pub(crate) fn fold_schema(schema_info: &SchemaInfo) -> Option> .entity_types() .filter_map(|et| et.loc.as_loc_ref()); let action_locs = validator.action_ids().filter_map(|a| a.loc()); - let common_types = validator - .common_types() + let common_types_extended = validator + .common_types_extended() .filter_map(|ct| ct.type_loc.as_loc_ref()); // Combine all locations and create folding ranges let ranges = namespace_locs .chain(entity_type_locs) .chain(action_locs) - .chain(common_types) + .chain(common_types_extended) .unique() .map(|loc| { let src_range = loc.to_range(); diff --git a/cedar-language-server/src/schema/symbols.rs b/cedar-language-server/src/schema/symbols.rs index 8271462b3b..001129e762 100644 --- a/cedar-language-server/src/schema/symbols.rs +++ b/cedar-language-server/src/schema/symbols.rs @@ -124,8 +124,8 @@ pub(crate) fn schema_symbols(schema_info: &SchemaInfo) -> Option = validator - .common_types() + let common_type_extended_symbols: Vec = validator + .common_types_extended() .filter_map(|ct| { ct.name_loc .as_ref() @@ -144,7 +144,7 @@ pub(crate) fn schema_symbols(schema_info: &SchemaInfo) -> Option Expr { self.subexpressions() .filter_map(|exp| match &exp.expr_kind { ExprKind::Slot(slotid) => Some(Slot { - id: *slotid, + id: slotid.clone(), loc: exp.source_loc().into_maybe_loc(), }), _ => None, }) } + /// Iterate over all of the principal or resource slots in this policy AST + pub fn principal_or_resource_slots(&self) -> impl Iterator + '_ { + self.subexpressions() + .filter_map(|exp| match &exp.expr_kind { + ExprKind::Slot(slotid) if slotid.is_principal() || slotid.is_resource() => { + Some(Slot { + id: slotid.clone(), + loc: exp.source_loc().into_maybe_loc(), + }) + } + _ => None, + }) + } + /// Determine if the expression is projectable under partial evaluation /// An expression is projectable if it's guaranteed to never error on evaluation /// This is true if the expression is entirely composed of values or unknowns @@ -1842,7 +1856,7 @@ mod test { let e = Expr::slot(SlotId::principal()); let p = SlotId::principal(); let r = SlotId::resource(); - let set: HashSet = HashSet::from_iter([p]); + let set: HashSet = HashSet::from_iter([p.clone()]); assert_eq!(set, e.slots().map(|slot| slot.id).collect::>()); let e = Expr::or( Expr::slot(SlotId::principal()), diff --git a/cedar-policy-core/src/ast/expr_visitor.rs b/cedar-policy-core/src/ast/expr_visitor.rs index ee6acd3f26..c5adb9b353 100644 --- a/cedar-policy-core/src/ast/expr_visitor.rs +++ b/cedar-policy-core/src/ast/expr_visitor.rs @@ -55,7 +55,7 @@ pub trait ExprVisitor { match expr.expr_kind() { ExprKind::Lit(lit) => self.visit_literal(lit, loc), ExprKind::Var(var) => self.visit_var(*var, loc), - ExprKind::Slot(slot) => self.visit_slot(*slot, loc), + ExprKind::Slot(slot) => self.visit_slot(slot.clone(), loc), ExprKind::Unknown(unknown) => self.visit_unknown(unknown, loc), ExprKind::If { test_expr, diff --git a/cedar-policy-core/src/ast/generalized_slots_annotation.rs b/cedar-policy-core/src/ast/generalized_slots_annotation.rs new file mode 100644 index 0000000000..08cd85140a --- /dev/null +++ b/cedar-policy-core/src/ast/generalized_slots_annotation.rs @@ -0,0 +1,119 @@ +/* + * Copyright Cedar Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use std::collections::BTreeMap; + +use crate::ast::SlotId; +use crate::extensions::Extensions; +use crate::validator::{ + json_schema::Type as JSONSchemaType, types::Type as ValidatorType, RawName, SchemaError, + ValidatorSchema, +}; +use serde::{Deserialize, Serialize}; +use serde_with::serde_as; + +/// Struct which holds the type & position of a generalized slot +#[derive(Clone, Eq, PartialEq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)] +#[serde_as] +pub struct GeneralizedSlotsAnnotation(BTreeMap>); + +impl GeneralizedSlotsAnnotation { + /// Create a new empty `GeneralizedSlotsAnnotation` (with no slots) + pub fn new() -> Self { + Self(BTreeMap::new()) + } + + /// Get the type of the slot by key + pub fn get(&self, key: &SlotId) -> Option<&JSONSchemaType> { + self.0.get(key) + } + + /// Iterate over all pairs of slots and their types + pub fn iter(&self) -> impl Iterator)> { + self.0.iter() + } + + /// Tell if it's empty + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub(crate) fn into_validator_generalized_slots_annotation( + self, + schema: &ValidatorSchema, + ) -> Result { + let validator_generalized_slots_annotation: Result, SchemaError> = self + .0 + .into_iter() + .map(|(k, ty)| -> Result<_, SchemaError> { + Ok(( + k, + schema.json_schema_type_to_validator_type(ty, Extensions::all_available())?, + )) + }) + .collect(); + Ok(validator_generalized_slots_annotation?.into()) + } +} + +impl Default for GeneralizedSlotsAnnotation { + fn default() -> Self { + Self::new() + } +} + +impl FromIterator<(SlotId, JSONSchemaType)> for GeneralizedSlotsAnnotation { + fn from_iter)>>(iter: T) -> Self { + Self(BTreeMap::from_iter(iter)) + } +} + +impl From>> for GeneralizedSlotsAnnotation { + fn from(value: BTreeMap>) -> Self { + Self(value) + } +} + +#[derive(Clone, Eq, PartialEq, PartialOrd, Ord, Debug, Hash)] +pub(crate) struct ValidatorGeneralizedSlotsAnnotation(BTreeMap); + +impl FromIterator<(SlotId, ValidatorType)> for ValidatorGeneralizedSlotsAnnotation { + fn from_iter>(iter: T) -> Self { + Self(BTreeMap::from_iter(iter)) + } +} + +impl From> for ValidatorGeneralizedSlotsAnnotation { + fn from(value: BTreeMap) -> Self { + Self(value) + } +} + +impl Default for ValidatorGeneralizedSlotsAnnotation { + fn default() -> Self { + Self::new() + } +} + +impl ValidatorGeneralizedSlotsAnnotation { + pub(crate) fn new() -> Self { + Self(BTreeMap::new()) + } + + pub(crate) fn get(&self, slot: &SlotId) -> Option<&ValidatorType> { + self.0.get(slot) + } +} diff --git a/cedar-policy-core/src/ast/name.rs b/cedar-policy-core/src/ast/name.rs index 20f2ae7cf0..e65b28c638 100644 --- a/cedar-policy-core/src/ast/name.rs +++ b/cedar-policy-core/src/ast/name.rs @@ -21,6 +21,7 @@ use miette::Diagnostic; use ref_cast::RefCast; use regex::Regex; use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_with::{serde_as, DisplayFromStr}; use smol_str::ToSmolStr; use std::collections::HashSet; use std::fmt::Display; @@ -283,9 +284,10 @@ impl<'de> Deserialize<'de> for InternalName { /// Clone is O(1). // This simply wraps a separate enum -- currently [`ValidSlotId`] -- in case we // want to generalize later -#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +#[serde_as] +#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)] #[serde(transparent)] -pub struct SlotId(pub(crate) ValidSlotId); +pub struct SlotId(#[serde_as(as = "DisplayFromStr")] pub(crate) ValidSlotId); impl SlotId { /// Get the slot for `principal` @@ -298,6 +300,11 @@ impl SlotId { Self(ValidSlotId::Resource) } + /// Create a `generalized slot` + pub fn generalized_slot(id: Id) -> Self { + Self(ValidSlotId::GeneralizedSlot(id)) + } + /// Check if a slot represents a principal pub fn is_principal(&self) -> bool { matches!(self, Self(ValidSlotId::Principal)) @@ -307,6 +314,11 @@ impl SlotId { pub fn is_resource(&self) -> bool { matches!(self, Self(ValidSlotId::Resource)) } + + /// Check if a slot represents a generalized slot + pub fn is_generalized_slot(&self) -> bool { + matches!(self, Self(ValidSlotId::GeneralizedSlot(_))) + } } impl From for SlotId { @@ -318,19 +330,26 @@ impl From for SlotId { } } +impl FromStr for SlotId { + type Err = ParseErrors; + + fn from_str(s: &str) -> Result { + crate::parser::parse_slot(s).map(SlotId) + } +} + impl std::fmt::Display for SlotId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } -/// Two possible variants for Slots -#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +/// Three possible variants for Slots +#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub(crate) enum ValidSlotId { - #[serde(rename = "?principal")] Principal, - #[serde(rename = "?resource")] Resource, + GeneralizedSlot(Id), // Slots for generalized templates, for more info see [RFC 98](https://github.com/cedar-policy/rfcs/pull/98). } impl std::fmt::Display for ValidSlotId { @@ -338,11 +357,20 @@ impl std::fmt::Display for ValidSlotId { let s = match self { ValidSlotId::Principal => "principal", ValidSlotId::Resource => "resource", + ValidSlotId::GeneralizedSlot(id) => id.as_ref(), }; write!(f, "?{s}") } } +impl FromStr for ValidSlotId { + type Err = ParseErrors; + + fn from_str(s: &str) -> Result { + crate::parser::parse_slot(s) + } +} + /// [`SlotId`] plus a source location #[derive(Educe, Debug, Clone)] #[educe(PartialEq, Eq, Hash)] diff --git a/cedar-policy-core/src/ast/policy.rs b/cedar-policy-core/src/ast/policy.rs index cce906c09b..300d0e0d58 100644 --- a/cedar-policy-core/src/ast/policy.rs +++ b/cedar-policy-core/src/ast/policy.rs @@ -15,9 +15,17 @@ */ use crate::ast::*; +use crate::entities::{conformance::typecheck_restricted_expr_against_schematype, SchemaType}; +use crate::extensions::Extensions; use crate::parser::{AsLocRef, IntoMaybeLoc, Loc, MaybeLoc}; +use crate::validator::{ + err::SchemaError, json_schema::Type as JSONSchemaType, + json_schema_type_to_validator_type_without_schema, types::Type as ValidatorType, RawName, + ValidatorSchema, +}; use annotation::{Annotation, Annotations}; use educe::Educe; +use generalized_slots_annotation::GeneralizedSlotsAnnotation; use itertools::Itertools; use miette::Diagnostic; use nonempty::{nonempty, NonEmpty}; @@ -53,6 +61,9 @@ cfg_tolerant_ast! { static DEFAULT_ANNOTATIONS: std::sync::LazyLock> = std::sync::LazyLock::new(|| Arc::new(Annotations::default())); + static DEFAULT_GENERALIZED_SLOTS_ANNOTATION: std::sync::LazyLock> = + std::sync::LazyLock::new(|| Arc::new(GeneralizedSlotsAnnotation::default())); + static DEFAULT_PRINCIPAL_CONSTRAINT: std::sync::LazyLock = std::sync::LazyLock::new(PrincipalConstraint::any); @@ -120,6 +131,7 @@ impl Template { id: PolicyID, loc: MaybeLoc, annotations: Annotations, + generalized_slots_annotation: GeneralizedSlotsAnnotation, effect: Effect, principal_constraint: PrincipalConstraint, action_constraint: ActionConstraint, @@ -130,6 +142,7 @@ impl Template { id, loc, annotations, + generalized_slots_annotation, effect, principal_constraint, action_constraint, @@ -154,6 +167,7 @@ impl Template { id: PolicyID, loc: MaybeLoc, annotations: Arc, + generalized_slots_annotation: Arc, effect: Effect, principal_constraint: PrincipalConstraint, action_constraint: ActionConstraint, @@ -164,6 +178,7 @@ impl Template { id, loc, annotations, + generalized_slots_annotation, effect, principal_constraint, action_constraint, @@ -238,6 +253,18 @@ impl Template { self.body.annotations_arc() } + /// Get all generalized_slots_annotation data. + pub fn generalized_slots_annotation( + &self, + ) -> impl Iterator)> { + self.body.generalized_slots_annotation() + } + + /// Get [`Arc`] owning the generalized slots annotation data. + pub fn generalized_slots_annotation_arc(&self) -> &Arc { + self.body.generalized_slots_annotation_arc() + } + /// Get the condition expression of this template. /// /// This will be a conjunction of the template's scope constraints (on @@ -266,15 +293,18 @@ impl Template { pub fn check_binding( template: &Template, values: &HashMap, + generalized_values: &HashMap, ) -> Result<(), LinkingError> { // Verify all slots bound let unbound = template .slots .iter() - .filter(|slot| !values.contains_key(&slot.id)) + .filter(|slot| { + !values.contains_key(&slot.id) && !generalized_values.contains_key(&slot.id) + }) .collect::>(); - let extra = values + let extra_values = values .iter() .filter_map(|(slot, _)| { if !template @@ -289,16 +319,211 @@ impl Template { }) .collect::>(); - if unbound.is_empty() && extra.is_empty() { + let extra_generalized_values = generalized_values + .iter() + .filter_map(|(slot, _)| { + if !template + .slots + .iter() + .any(|template_slot| template_slot.id == *slot) + { + Some(slot) + } else { + None + } + }) + .collect::>(); + + let invalid_keys_in_values = values + .iter() + .filter_map(|(slot, _)| { + if !(*slot == (SlotId::principal()) || *slot == (SlotId::resource())) { + Some(slot) + } else { + None + } + }) + .collect::>(); + + let invalid_keys_in_generalized_values = generalized_values + .iter() + .filter_map(|(slot, _)| { + if *slot == (SlotId::principal()) || *slot == (SlotId::resource()) { + Some(slot) + } else { + None + } + }) + .collect::>(); + + if unbound.is_empty() + && extra_values.is_empty() + && extra_generalized_values.is_empty() + && invalid_keys_in_values.is_empty() + && invalid_keys_in_generalized_values.is_empty() + { Ok(()) + } else if !(invalid_keys_in_values.is_empty()) { + Err(LinkingError::from_invalid_env( + invalid_keys_in_values.into_iter().cloned(), + )) + } else if !(invalid_keys_in_generalized_values.is_empty()) { + Err(LinkingError::from_invalid_generalized_env( + invalid_keys_in_generalized_values.into_iter().cloned(), + )) } else { Err(LinkingError::from_unbound_and_extras( - unbound.into_iter().map(|slot| slot.id), - extra.into_iter().copied(), + unbound.into_iter().map(|slot| slot.id.clone()), + extra_values + .into_iter() + .cloned() + .chain(extra_generalized_values.into_iter().cloned()), )) } } + /// Validates that the values provided for the generalized slots are of the types annotated + pub fn link_time_type_checking_with_schema( + template: &Template, + schema: &ValidatorSchema, + values: &HashMap, + generalized_values: &HashMap, + ) -> Result<(), LinkingError> { + let validator_generalized_slots_annotation = GeneralizedSlotsAnnotation::from_iter( + template + .generalized_slots_annotation() + .map(|(k, v)| (k.clone(), v.clone())), + ) + .into_validator_generalized_slots_annotation(schema)?; + + for (slot, entity_uid) in values { + let restricted_expr = &RestrictedExpr::val(entity_uid.clone()); + + if let Some(validator_type) = validator_generalized_slots_annotation.get(slot) { + let borrowed_restricted_expr = restricted_expr.as_borrowed(); + #[allow(clippy::expect_used)] + let schema_ty = &SchemaType::try_from(validator_type.clone()).expect( + "This should never happen as expected_ty is a statically annotated type", + ); + let extensions = Extensions::all_available(); + typecheck_restricted_expr_against_schematype( + borrowed_restricted_expr, + schema_ty, + extensions, + ) + .map_err(|_| { + LinkingError::ValueProvidedForSlotIsNotOfTypeSpecified { + slot: slot.clone(), + value: restricted_expr.clone(), + ty: validator_type.clone(), + } + })? + } + } + + for (slot, restricted_expr) in generalized_values { + let validator_type = validator_generalized_slots_annotation.get(slot).ok_or( + LinkingError::ArityError { + unbound_values: vec![slot.clone()], + extra_values: vec![], + }, + )?; + let borrowed_restricted_expr = restricted_expr.as_borrowed(); + #[allow(clippy::expect_used)] + let schema_ty = &SchemaType::try_from(validator_type.clone()) + .expect("This should never happen as expected_ty is a statically annotated type"); + let extensions = Extensions::all_available(); + typecheck_restricted_expr_against_schematype( + borrowed_restricted_expr, + schema_ty, + extensions, + ) + .map_err(|_| LinkingError::ValueProvidedForSlotIsNotOfTypeSpecified { + slot: slot.clone(), + value: restricted_expr.clone(), + ty: validator_type.clone(), + })? + } + Ok(()) + } + + /// Validates that the values provided for the generalized slots are of the types annotated + pub fn link_time_type_checking_without_schema( + template: &Template, + values: &HashMap, + generalized_values: &HashMap, + ) -> Result<(), LinkingError> { + let generalized_slots_annotation = GeneralizedSlotsAnnotation::from_iter( + template + .generalized_slots_annotation() + .map(|(k, v)| (k.clone(), v.clone())), + ); + + for (slot, entity_uid) in values { + let restricted_expr = &RestrictedExpr::val(entity_uid.clone()); + + if let Some(raw_type) = generalized_slots_annotation.get(slot) { + let extensions = Extensions::all_available(); + + let validator_type = json_schema_type_to_validator_type_without_schema( + raw_type.clone(), + extensions, + )?; + + let borrowed_restricted_expr = restricted_expr.as_borrowed(); + #[allow(clippy::expect_used)] + let schema_ty = &SchemaType::try_from(validator_type.clone()).expect( + "This should never happen as expected_ty is a statically annotated type", + ); + + typecheck_restricted_expr_against_schematype( + borrowed_restricted_expr, + schema_ty, + extensions, + ) + .map_err(|_| { + LinkingError::ValueProvidedForSlotIsNotOfTypeSpecified { + slot: slot.clone(), + value: restricted_expr.clone(), + ty: validator_type.clone(), + } + })? + } + } + + for (slot, restricted_expr) in generalized_values { + let raw_type = + generalized_slots_annotation + .get(slot) + .ok_or(LinkingError::ArityError { + unbound_values: vec![slot.clone()], + extra_values: vec![], + })?; + let extensions = Extensions::all_available(); + + let validator_type = + json_schema_type_to_validator_type_without_schema(raw_type.clone(), extensions)?; + + let borrowed_restricted_expr = restricted_expr.as_borrowed(); + #[allow(clippy::expect_used)] + let schema_ty = &SchemaType::try_from(validator_type.clone()) + .expect("This should never happen as expected_ty is a statically annotated type"); + + typecheck_restricted_expr_against_schematype( + borrowed_restricted_expr, + schema_ty, + extensions, + ) + .map_err(|_| LinkingError::ValueProvidedForSlotIsNotOfTypeSpecified { + slot: slot.clone(), + value: restricted_expr.clone(), + ty: validator_type.clone(), + })? + } + + Ok(()) + } + /// Attempt to create a template-linked policy from this template. /// This will fail if values for all open slots are not given. /// `new_instance_id` is the `PolicyId` for the created template-linked policy. @@ -306,10 +531,28 @@ impl Template { template: Arc