diff --git a/examples/src/main/java/io/opentdf/platform/GetEntitlements.java b/examples/src/main/java/io/opentdf/platform/GetEntitlements.java index 0f0ab735..f9479577 100644 --- a/examples/src/main/java/io/opentdf/platform/GetEntitlements.java +++ b/examples/src/main/java/io/opentdf/platform/GetEntitlements.java @@ -7,7 +7,6 @@ import io.opentdf.platform.sdk.*; import java.util.Collections; -import java.util.concurrent.ExecutionException; import java.util.List; diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/Autoconfigure.java b/sdk/src/main/java/io/opentdf/platform/sdk/Autoconfigure.java index b60d0725..e17b69cd 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/Autoconfigure.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/Autoconfigure.java @@ -1,12 +1,12 @@ package io.opentdf.platform.sdk; import com.connectrpc.ResponseMessageKt; +import io.opentdf.platform.policy.Algorithm; import io.opentdf.platform.policy.Attribute; import io.opentdf.platform.policy.AttributeRuleTypeEnum; import io.opentdf.platform.policy.AttributeValueSelector; -import io.opentdf.platform.policy.KasPublicKey; -import io.opentdf.platform.policy.KasPublicKeyAlgEnum; import io.opentdf.platform.policy.KeyAccessServer; +import io.opentdf.platform.policy.SimpleKasKey; import io.opentdf.platform.policy.Value; import io.opentdf.platform.policy.attributes.AttributesServiceClientInterface; import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsRequest; @@ -14,6 +14,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.io.UnsupportedEncodingException; import java.net.URLDecoder; @@ -27,6 +28,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.StringJoiner; import java.util.function.Supplier; @@ -55,45 +57,80 @@ class RuleType { */ public class Autoconfigure { - public static Logger logger = LoggerFactory.getLogger(Autoconfigure.class); + private static Logger logger = LoggerFactory.getLogger(Autoconfigure.class); - public static class KeySplitStep { - public String kas; - public String splitID; + private Autoconfigure() { + // Prevent instantiation, this class is a utility class that is only used statically + } - public KeySplitStep(String kas, String splitId) { - this.kas = kas; - this.splitID = splitId; - } + static class KeySplitTemplate { + final String kas; + final String splitID; + final String kid; + final KeyType keyType; @Override public String toString() { - return "KeySplitStep{kas=" + this.kas + ", splitID=" + this.splitID + "}"; + return "KeySplitTemplate{" + + "kas='" + kas + '\'' + + ", splitID='" + splitID + '\'' + + ", kid='" + kid + '\'' + + ", keyType=" + keyType + + '}'; } @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null || !(obj instanceof KeySplitStep)) { - return false; - } - KeySplitStep ss = (KeySplitStep) obj; - if ((this.kas.equals(ss.kas)) && (this.splitID.equals(ss.splitID))) { - return true; - } - return false; + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + KeySplitTemplate that = (KeySplitTemplate) o; + return Objects.equals(kas, that.kas) && Objects.equals(splitID, that.splitID) && Objects.equals(kid, that.kid) && keyType == that.keyType; + } + + @Override + public int hashCode() { + return Objects.hash(kas, splitID, kid, keyType); + } + + public KeySplitTemplate(String kas, String splitID, String kid, KeyType keyType) { + this.kas = kas; + this.splitID = splitID; + this.kid = kid; + this.keyType = keyType; + } + } + + public static class KeySplitStep { + final String kas; + final String splitID; + + KeySplitStep(String kas, String splitId) { + this.kas = Objects.requireNonNull(kas); + this.splitID = Objects.requireNonNull(splitId); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + KeySplitStep that = (KeySplitStep) o; + return Objects.equals(kas, that.kas) && Objects.equals(splitID, that.splitID); } @Override public int hashCode() { return Objects.hash(kas, splitID); } + + @Override + public String toString() { + return "KeySplitStep{" + + "kas='" + kas + '\'' + + ", splitID='" + splitID + '\'' + + '}'; + } } // Utility class for an attribute name FQN. - public static class AttributeNameFQN { + static class AttributeNameFQN { private final String url; private final String key; @@ -156,7 +193,7 @@ public String name() throws AutoConfigureException { } // Utility class for an attribute value FQN. - public static class AttributeValueFQN { + static class AttributeValueFQN { private final String url; private final String key; @@ -203,11 +240,11 @@ public int hashCode() { return Objects.hash(key); } - public String getKey() { + String getKey() { return key; } - public String authority() { + String authority() { Pattern pattern = Pattern.compile("^(https?://[\\w./-]+)/attr/\\S*/value/\\S*$"); Matcher matcher = pattern.matcher(url); if (!matcher.find()) { @@ -216,7 +253,7 @@ public String authority() { return matcher.group(1); } - public AttributeNameFQN prefix() throws AutoConfigureException { + AttributeNameFQN prefix() throws AutoConfigureException { Pattern pattern = Pattern.compile("^(https?://[\\w./-]+/attr/\\S*)/value/\\S*$"); Matcher matcher = pattern.matcher(url); if (!matcher.find()) { @@ -225,7 +262,7 @@ public AttributeNameFQN prefix() throws AutoConfigureException { return new AttributeNameFQN(matcher.group(1)); } - public String value() { + String value() { Pattern pattern = Pattern.compile("^https?://[\\w./-]+/attr/\\S*/value/(\\S*)$"); Matcher matcher = pattern.matcher(url); if (!matcher.find()) { @@ -238,7 +275,7 @@ public String value() { } } - public String name() { + String name() { Pattern pattern = Pattern.compile("^https?://[\\w./-]+/attr/(\\S*)/value/\\S*$"); Matcher matcher = pattern.matcher(url); if (!matcher.find()) { @@ -246,13 +283,13 @@ public String name() { } try { return URLDecoder.decode(matcher.group(1), StandardCharsets.UTF_8.name()); - } catch (UnsupportedEncodingException | IllegalArgumentException e) { - throw new RuntimeException("invalid attributeInstance", e); + } catch (UnsupportedEncodingException | IllegalArgumentException e) { + throw new RuntimeException("illegal attribute instance", e); } } } - public static class KeyAccessGrant { + static class KeyAccessGrant { public Attribute attr; public List kases; @@ -266,40 +303,104 @@ public KeyAccessGrant(Attribute attr, List kases) { static class Granter { private final List policy; private final Map grants = new HashMap<>(); + private final Map> mappedKeys = new HashMap<>(); + private boolean hasGrants = false; + private boolean hasMappedKeys = false; - public Granter(List policy) { + Granter(List policy) { this.policy = policy; } - public Map getGrants() { - return new HashMap(grants); + Map getGrants() { + return new HashMap<>(grants); } - public List getPolicy() { + List getPolicy() { return policy; } - public void addGrant(AttributeValueFQN fqn, String kas, Attribute attr) { - grants.computeIfAbsent(fqn.key, k -> new KeyAccessGrant(attr, new ArrayList<>())).kases.add(kas); - } + boolean addAllGrants(AttributeValueFQN fqn, List granted, List mapped, Attribute attr) { + boolean foundMappedKey = false; + for (var mappedKey: mapped) { + foundMappedKey = true; + mappedKeys.computeIfAbsent(fqn.key, k -> new ArrayList<>()).add(Config.KASInfo.fromSimpleKasKey(mappedKey)); + grants.computeIfAbsent(fqn.key, k -> new KeyAccessGrant(attr, new ArrayList<>())).kases.add(mappedKey.getKasUri()); + } - public void addAllGrants(AttributeValueFQN fqn, List gs, Attribute attr) { - if (gs.isEmpty()) { - grants.putIfAbsent(fqn.key, new KeyAccessGrant(attr, new ArrayList<>())); - } else { - for (KeyAccessServer g : gs) { - if (g != null) { - addGrant(fqn, g.getUri(), attr); + if (foundMappedKey) { + hasMappedKeys = true; + return true; + } + + boolean foundGrantedKey = false; + for (var grantedKey: granted) { + foundGrantedKey = true; + grants.computeIfAbsent(fqn.key, k -> new KeyAccessGrant(attr, new ArrayList<>())).kases.add(grantedKey.getUri()); + if (!grantedKey.getKasKeysList().isEmpty()) { + for (var kas : grantedKey.getKasKeysList()) { + mappedKeys.computeIfAbsent(fqn.key, k -> new ArrayList<>()).add(Config.KASInfo.fromSimpleKasKey(kas)); } + continue; } + var cachedGrantKeys = grantedKey.getPublicKey().getCached().getKeysList(); + + if (logger.isDebugEnabled()) { + logger.debug("found {} keys cached in policy service", cachedGrantKeys.size()); + } + + for (var cachedGrantKey: cachedGrantKeys) { + var mappedKey = new Config.KASInfo(); + mappedKey.URL = grantedKey.getUri(); + mappedKey.KID = cachedGrantKey.getKid(); + mappedKey.Algorithm = KeyType.fromPublicKeyAlgorithm(cachedGrantKey.getAlg()).toString(); + mappedKey.PublicKey = cachedGrantKey.getPem(); + mappedKey.Default = false; + mappedKeys.computeIfAbsent(fqn.key, k -> new ArrayList<>()).add(mappedKey); + } + } + + if (!grants.containsKey(fqn.key)) { + grants.put(fqn.key, new KeyAccessGrant(attr, new ArrayList<>())); + } + + if (foundGrantedKey) { + hasGrants = true; } + return foundGrantedKey; } - public KeyAccessGrant byAttribute(AttributeValueFQN fqn) { + KeyAccessGrant byAttribute(AttributeValueFQN fqn) { return grants.get(fqn.key); } - public List plan(List defaultKas, Supplier genSplitID) + List getSplits(List defaultKases, Supplier genSplitID, Supplier> baseKeySupplier) throws AutoConfigureException { + if (hasMappedKeys) { + logger.debug("generating plan from mapped keys"); + return planFromAttributes(genSplitID); + } + if (hasGrants) { + logger.debug("generating plan from grants"); + return planUsingGrants(genSplitID); + } + + var baseKey = baseKeySupplier.get(); + if (baseKey.isPresent()) { + var key = baseKey.get(); + String kas = key.getKasUri(); + String splitID = ""; + String kid = key.getPublicKey().getKid(); + Algorithm algorithm = key.getPublicKey().getAlgorithm(); + return Collections.singletonList(new KeySplitTemplate(kas, splitID, kid, KeyType.fromAlgorithm(algorithm))); + } + + logger.warn("no grants or mapped keys found, generating plan from default KASes. this is deprecated"); + // this is a little bit weird because we don't take into account the KIDs here. This is the way + // that it works in the go SDK but it seems a bit odd + return generatePlanFromDefaultKases(defaultKases, genSplitID); + } + + @Nonnull + List planUsingGrants(Supplier genSplitID) throws AutoConfigureException { AttributeBooleanExpression b = constructAttributeBoolean(); BooleanKeyExpression k = insertKeysForAttribute(b); @@ -310,30 +411,60 @@ public List plan(List defaultKas, Supplier genSpli k = k.reduce(); int l = k.size(); if (l == 0) { - // default behavior: split key across all default KAS - if (defaultKas.isEmpty()) { - throw new AutoConfigureException("no default KAS specified; required for grantless plans"); - } else if (defaultKas.size() == 1) { - return Collections.singletonList(new KeySplitStep(defaultKas.get(0), "")); - } else { - List result = new ArrayList<>(); - for (String kas : defaultKas) { - result.add(new KeySplitStep(kas, genSplitID.get())); - } - return result; + throw new AutoConfigureException("generated an empty plan"); + } + + List steps = new ArrayList<>(); + for (KeyClause v : k.values) { + String splitID = (l > 1) ? genSplitID.get() : ""; + for (PublicKeyInfo o : v.values) { + // grants only have KAS URLs, no KIDs or algorithms + steps.add(new KeySplitTemplate(o.kas, splitID, null, null)); } } + return steps; + } + + @Nonnull + List planFromAttributes(Supplier genSplitID) + throws AutoConfigureException { + AttributeBooleanExpression b = constructAttributeBoolean(); + BooleanKeyExpression k = assignKeysTo(b); + if (k == null) { + throw new AutoConfigureException("Error assigning keys to attribute"); + } + + k = k.reduce(); + int l = k.size(); + if (l == 0) { + return Collections.emptyList(); + } - List steps = new ArrayList<>(); + List steps = new ArrayList<>(); for (KeyClause v : k.values) { String splitID = (l > 1) ? genSplitID.get() : ""; for (PublicKeyInfo o : v.values) { - steps.add(new KeySplitStep(o.kas, splitID)); + KeyType keyType = o.algorithm != null ? KeyType.fromString(o.algorithm) : null; + steps.add(new KeySplitTemplate(o.kas, splitID, o.kid, keyType)); } } return steps; } + static List generatePlanFromDefaultKases(List defaultKas, Supplier genSplitID) { + if (defaultKas.isEmpty()) { + throw new AutoConfigureException("no default KAS specified; required for grantless plans"); + } else if (defaultKas.size() == 1) { + return Collections.singletonList(new KeySplitTemplate(defaultKas.get(0), "", null, null)); + } else { + List result = new ArrayList<>(); + for (String kas : defaultKas) { + result.add(new KeySplitTemplate(kas, genSplitID.get(), null, null)); + } + return result; + } + } + BooleanKeyExpression insertKeysForAttribute(AttributeBooleanExpression e) throws AutoConfigureException { List kcs = new ArrayList<>(e.must.size()); @@ -361,13 +492,53 @@ BooleanKeyExpression insertKeysForAttribute(AttributeBooleanExpression e) throws logger.warn("Unknown attribute rule type: " + clause); } - KeyClause kc = new KeyClause(op, kcv); - kcs.add(kc); + kcs.add(new KeyClause(op, kcv)); } return new BooleanKeyExpression(kcs); } + BooleanKeyExpression assignKeysTo(AttributeBooleanExpression e) { + var keyClauses = new ArrayList(); + for (var clause : e.must) { + ArrayList keys = new ArrayList<>(); + if (clause.values.isEmpty()) { + logger.warn("No values found for attribute {}", clause.def.getFqn()); + continue; + } + for (var value : clause.values) { + var mapped = mappedKeys.get(value.key); + if (mapped == null) { + logger.warn("No keys found for attribute value {}", value); + continue; + } + for (var kasInfo : mapped) { + if (kasInfo.URL == null || kasInfo.URL.isEmpty()) { + logger.warn("No KAS URL found for attribute value {}", value); + continue; + } + keys.add(new PublicKeyInfo(kasInfo.URL, kasInfo.KID, kasInfo.Algorithm)); + } + } + + String op = ruleToOperator(clause.def.getRule()); + if (op.equals(RuleType.UNSPECIFIED)) { + logger.warn("Unknown attribute rule type {}", op); + } + + keyClauses.add(new KeyClause(op, keys)); + } + + return new BooleanKeyExpression(keyClauses); + } + + /** + * Constructs an AttributeBooleanExpression from the policy, splitting each attribute + * into its own clause. Each clause contains the attribute definition and a list of + * values. + * @return + * @throws AutoConfigureException + */ AttributeBooleanExpression constructAttributeBoolean() throws AutoConfigureException { Map prefixes = new HashMap<>(); List sortedPrefixes = new ArrayList<>(); @@ -378,7 +549,7 @@ AttributeBooleanExpression constructAttributeBoolean() throws AutoConfigureExcep clause.values.add(aP); } else if (byAttribute(aP) != null) { var x = new SingleAttributeClause(byAttribute(aP).attr, - new ArrayList(Arrays.asList(aP))); + new ArrayList<>(Arrays.asList(aP))); prefixes.put(a.getKey(), x); sortedPrefixes.add(a.getKey()); } @@ -391,39 +562,6 @@ AttributeBooleanExpression constructAttributeBoolean() throws AutoConfigureExcep return new AttributeBooleanExpression(must); } - static class AttributeMapping { - - private Map dict; - - public AttributeMapping() { - this.dict = new HashMap<>(); - } - - public void put(Attribute ad) throws AutoConfigureException { - if (this.dict == null) { - this.dict = new HashMap<>(); - } - - AttributeNameFQN prefix = new AttributeNameFQN(ad.getFqn()); - - if (this.dict.containsKey(prefix)) { - throw new AutoConfigureException("Attribute prefix already found: [" + prefix.toString() + "]"); - } - - this.dict.put(prefix, ad); - } - - public Attribute get(AttributeNameFQN prefix) throws AutoConfigureException { - Attribute ad = this.dict.get(prefix); - if (ad == null) { - throw new AutoConfigureException("Unknown attribute type: [" + prefix.toString() + "], not in [" - + this.dict.keySet().toString() + "]"); - } - return ad; - } - - } - static class SingleAttributeClause { private Attribute def; @@ -435,9 +573,9 @@ public SingleAttributeClause(Attribute def, List values) { } } - class AttributeBooleanExpression { + static class AttributeBooleanExpression { - private List must; + private final List must; public AttributeBooleanExpression(List must) { this.must = must; @@ -475,28 +613,65 @@ public String toString() { } return sb.toString(); } - } - public class PublicKeyInfo { - private String kas; + static class PublicKeyInfo implements Comparable { + final String kas; + final String kid; + final String algorithm; - public PublicKeyInfo(String kas) { - this.kas = kas; + PublicKeyInfo(String kas) { + this(kas, null, null); } - public String getKas() { + PublicKeyInfo(String kas, String kid, String algorithm) { + this.kas = Objects.requireNonNull(kas); + this.kid = kid; + this.algorithm = algorithm; + } + + String getKas() { return kas; } - public void setKas(String kas) { - this.kas = kas; + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + PublicKeyInfo that = (PublicKeyInfo) o; + return Objects.equals(kas, that.kas) && Objects.equals(kid, that.kid) && Objects.equals(algorithm, that.algorithm); + } + + @Override + public int hashCode() { + return Objects.hash(kas, kid, algorithm); + } + + @Override + public int compareTo(PublicKeyInfo o) { + if (this.kas.compareTo(o.kas) != 0) { + return this.kas.compareTo(o.kas); + } + if ((this.kid == null) != (o.kid == null)) { + return this.kid == null ? -1 : 1; + } + if (this.kid != null) { + if (this.kid.compareTo(o.kid) != 0) { + return this.kid.compareTo(o.kid); + } + } + if ((this.algorithm == null) != (o.algorithm == null)) { + return this.algorithm == null ? -1 : 1; + } + if (this.algorithm != null) { + return this.algorithm.compareTo(o.algorithm); + } + return 0; } } - public class KeyClause { - private String operator; - private List values; + static class KeyClause { + private final String operator; + private final List values; public KeyClause(String operator, List values) { this.operator = operator; @@ -531,8 +706,8 @@ public String toString() { } } - public class BooleanKeyExpression { - private List values; + static class BooleanKeyExpression { + private final List values; public BooleanKeyExpression(List values) { this.values = values; @@ -572,7 +747,7 @@ public BooleanKeyExpression reduce() { continue; } Disjunction terms = new Disjunction(); - terms.add(k.getKas()); + terms.add(k); if (!within(conjunction, terms)) { conjunction.add(terms); } @@ -584,25 +759,22 @@ public BooleanKeyExpression reduce() { } List newValues = new ArrayList<>(); - for (List d : conjunction) { + for (List d : conjunction) { List pki = new ArrayList<>(); - for (String k : d) { - pki.add(new PublicKeyInfo(k)); - } + pki.addAll(d); newValues.add(new KeyClause(RuleType.ANY_OF, pki)); } return new BooleanKeyExpression(newValues); } public Disjunction sortedNoDupes(List l) { - Set set = new HashSet<>(); + Set set = new HashSet<>(); Disjunction list = new Disjunction(); for (PublicKeyInfo e : l) { - String kas = e.getKas(); - if (!kas.equals(RuleType.EMPTY_TERM) && !set.contains(kas)) { - set.add(kas); - list.add(kas); + if (!Objects.equals(e.getKas(), RuleType.EMPTY_TERM) && !set.contains(e)) { + set.add(e); + list.add(e); } } @@ -612,7 +784,7 @@ public Disjunction sortedNoDupes(List l) { } - class Disjunction extends ArrayList { + static class Disjunction extends ArrayList { public boolean less(Disjunction r) { int m = Math.min(this.size(), r.size()); @@ -683,7 +855,7 @@ public static String ruleToOperator(AttributeRuleTypeEnum e) { // Given a policy (list of data attributes or tags), // get a set of grants from attribute values to KASes. // Unlike `NewGranterFromService`, this works offline. - public static Granter newGranterFromAttributes(Value... attrValues) throws AutoConfigureException { + static Granter newGranterFromAttributes(KASKeyCache keyCache, Value... attrValues) throws AutoConfigureException { var attrsAndValues = Arrays.stream(attrValues).map(v -> { if (!v.hasAttribute()) { throw new AutoConfigureException("tried to use an attribute that is not initialized"); @@ -694,11 +866,11 @@ public static Granter newGranterFromAttributes(Value... attrValues) throws AutoC .build(); }).collect(Collectors.toList()); - return getGranter(null, attrsAndValues); + return getGranter(keyCache, attrsAndValues); } // Gets a list of directory of KAS grants for a list of attribute FQNs - public static Granter newGranterFromService(AttributesServiceClientInterface as, KASKeyCache keyCache, AttributeValueFQN... fqns) throws AutoConfigureException { + static Granter newGranterFromService(AttributesServiceClientInterface as, KASKeyCache keyCache, AttributeValueFQN... fqns) throws AutoConfigureException { GetAttributeValuesByFqnsRequest request = GetAttributeValuesByFqnsRequest.newBuilder() .addAllFqns(Arrays.stream(fqns).map(AttributeValueFQN::toString).collect(Collectors.toList())) .setWithValue(AttributeValueSelector.newBuilder().setWithKeyAccessGrants(true).build()) @@ -711,82 +883,57 @@ public static Granter newGranterFromService(AttributesServiceClientInterface as, return getGranter(keyCache, new ArrayList<>(av.getFqnAttributeValuesMap().values())); } - private static List getGrants(GetAttributeValuesByFqnsResponse.AttributeAndValue attributeAndValue) { - var val = attributeAndValue.getValue(); - var attribute = attributeAndValue.getAttribute(); - if (!val.getGrantsList().isEmpty()) { - if (logger.isDebugEnabled()) { - logger.debug("adding grants from attribute value [{}]: {}", val.getFqn(), val.getGrantsList().stream().map(KeyAccessServer::getUri).collect(Collectors.toList())); - } - return val.getGrantsList(); - } else if (!attribute.getGrantsList().isEmpty()) { - var attributeGrants = attribute.getGrantsList(); - if (logger.isDebugEnabled()) { - logger.debug("adding grants from attribute [{}]: {}", attribute.getFqn(), attributeGrants.stream().map(KeyAccessServer::getId).collect(Collectors.toList())); - } - return attributeGrants; - } else if (!attribute.getNamespace().getGrantsList().isEmpty()) { - var nsGrants = attribute.getNamespace().getGrantsList(); - if (logger.isDebugEnabled()) { - logger.debug("adding grants from namespace [{}]: [{}]", attribute.getNamespace().getName(), nsGrants.stream().map(KeyAccessServer::getId).collect(Collectors.toList())); - } - return nsGrants; - } else { - // this is needed to mark the fact that we have an empty - if (logger.isDebugEnabled()) { - logger.debug("didn't find any grants on value, attribute, or namespace for attribute value [{}]", val.getFqn()); - } - return Collections.emptyList(); + static Autoconfigure.Granter createGranter(SDK.Services services, Config.TDFConfig tdfConfig) { + Autoconfigure.Granter granter = new Autoconfigure.Granter(new ArrayList<>()); + if (tdfConfig.attributeValues != null && !tdfConfig.attributeValues.isEmpty()) { + granter = Autoconfigure.newGranterFromAttributes(services.kas().getKeyCache(), tdfConfig.attributeValues.toArray(new Value[0])); + } else if (tdfConfig.attributes != null && !tdfConfig.attributes.isEmpty()) { + granter = Autoconfigure.newGranterFromService(services.attributes(), services.kas().getKeyCache(), + tdfConfig.attributes.toArray(new Autoconfigure.AttributeValueFQN[0])); } - + return granter; } private static Granter getGranter(@Nullable KASKeyCache keyCache, List values) { - Granter grants = new Granter(values.stream().map(GetAttributeValuesByFqnsResponse.AttributeAndValue::getValue).map(Value::getFqn).map(AttributeValueFQN::new).collect(Collectors.toList())); + List attributeValues = values.stream() + .map(GetAttributeValuesByFqnsResponse.AttributeAndValue::getValue) + .map(Value::getFqn) + .map(AttributeValueFQN::new) + .collect(Collectors.toList()); + Granter grants = new Granter(attributeValues); for (var attributeAndValue: values) { - var attributeGrants = getGrants(attributeAndValue); String fqnstr = attributeAndValue.getValue().getFqn(); AttributeValueFQN fqn = new AttributeValueFQN(fqnstr); - grants.addAllGrants(fqn, attributeGrants, attributeAndValue.getAttribute()); - if (keyCache != null) { - storeKeysToCache(attributeGrants, keyCache); - } - } - - return grants; - } + var value = attributeAndValue.getValue(); + var attribute = attributeAndValue.getAttribute(); + var namespace = attribute.getNamespace(); - static void storeKeysToCache(List kases, KASKeyCache keyCache) { - for (KeyAccessServer kas : kases) { - List keys = kas.getPublicKey().getCached().getKeysList(); - if (keys.isEmpty()) { - logger.debug("No cached key in policy service for KAS: " + kas.getUri()); + if (grants.addAllGrants(fqn, value.getGrantsList(), value.getKasKeysList(), attribute)) { + storeKeysToCache(value.getGrantsList(), value.getKasKeysList(), keyCache); + continue; + } + if (grants.addAllGrants(fqn, attribute.getGrantsList(), attribute.getKasKeysList(), attribute)) { + storeKeysToCache(attribute.getGrantsList(), attribute.getKasKeysList(), keyCache); continue; } - for (KasPublicKey ki : keys) { - Config.KASInfo kasInfo = new Config.KASInfo(); - kasInfo.URL = kas.getUri(); - kasInfo.KID = ki.getKid(); - kasInfo.Algorithm = algProto2String(ki.getAlg()); - kasInfo.PublicKey = ki.getPem(); - keyCache.store(kasInfo); + if (grants.addAllGrants(fqn, namespace.getGrantsList(), namespace.getKasKeysList(), attribute)) { + storeKeysToCache(namespace.getGrantsList(), namespace.getKasKeysList(), keyCache); } } + + return grants; } - private static String algProto2String(KasPublicKeyAlgEnum e) { - switch (e) { - case KAS_PUBLIC_KEY_ALG_ENUM_EC_SECP256R1: - return "ec:secp256r1"; - case KAS_PUBLIC_KEY_ALG_ENUM_RSA_2048: - return "rsa:2048"; - case KAS_PUBLIC_KEY_ALG_ENUM_UNSPECIFIED: - default: - return ""; + static void storeKeysToCache(List kases, List kasKeys, KASKeyCache keyCache) { + if (keyCache == null) { + return; } + for (var kas : kases) { + Config.KASInfo.fromKeyAccessServer(kas).forEach(keyCache::store); + } + kasKeys.stream().map(Config.KASInfo::fromSimpleKasKey).forEach(keyCache::store); } - } diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/Config.java b/sdk/src/main/java/io/opentdf/platform/sdk/Config.java index 0a1f4a46..88757be0 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/Config.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/Config.java @@ -1,13 +1,18 @@ package io.opentdf.platform.sdk; +import io.opentdf.platform.policy.KeyAccessServer; +import io.opentdf.platform.policy.SimpleKasKey; import io.opentdf.platform.policy.Value; import io.opentdf.platform.sdk.Autoconfigure.AttributeValueFQN; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.net.URI; import java.net.URISyntaxException; import java.util.*; import java.util.function.Consumer; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * Configuration class for setting various configurations related to TDF. @@ -22,6 +27,7 @@ public class Config { public static final String KAS_PUBLIC_KEY_PATH = "/kas_public_key"; public static final String DEFAULT_MIME_TYPE = "application/octet-stream"; public static final int MAX_COLLECTION_ITERATION = (1 << 24) - 1; + private static Logger logger = LoggerFactory.getLogger(Config.class); public enum TDFFormat { JSONFormat, @@ -33,8 +39,6 @@ public enum IntegrityAlgorithm { GMAC } - public static final int K_HTTP_OK = 200; - public static class KASInfo implements Cloneable { public String URL; public String PublicKey; @@ -71,6 +75,36 @@ public String toString() { } return sb.append("}").toString(); } + + public static List fromKeyAccessServer(KeyAccessServer kas) { + var keys = kas.getPublicKey().getCached().getKeysList(); + if (keys.isEmpty()) { + logger.warn("Invalid KAS key mapping for kas [{}]: publicKey is empty", kas.getUri()); + return Collections.emptyList(); + } + return keys.stream().flatMap(ki -> { + if (ki.getPem().isEmpty()) { + logger.warn("Invalid KAS key mapping for kas [{}]: publicKey PEM is empty", kas.getUri()); + return Stream.empty(); + } + Config.KASInfo kasInfo = new Config.KASInfo(); + kasInfo.URL = kas.getUri(); + kasInfo.KID = ki.getKid(); + kasInfo.Algorithm = KeyType.fromPublicKeyAlgorithm(ki.getAlg()).toString(); + kasInfo.PublicKey = ki.getPem(); + return Stream.of(kasInfo); + }).collect(Collectors.toList()); + } + + public static KASInfo fromSimpleKasKey(SimpleKasKey ki) { + Config.KASInfo kasInfo = new Config.KASInfo(); + kasInfo.URL = ki.getKasUri(); + kasInfo.KID = ki.getPublicKey().getKid(); + kasInfo.Algorithm = KeyType.fromAlgorithm(ki.getPublicKey().getAlgorithm()).toString(); + kasInfo.PublicKey = ki.getPublicKey().getPem(); + + return kasInfo; + } } public static class AssertionVerificationKeys { @@ -239,6 +273,11 @@ public static Consumer withKasInformation(KASInfo... kasInfoList) { }; } + /** + * Deprecated since 9.1.0, will be removed. To produce key shares use + * the key mapping feature + */ + @Deprecated(since = "9.1.0", forRemoval = true) public static Consumer withSplitPlan(Autoconfigure.KeySplitStep... p) { return (TDFConfig config) -> { config.splitPlan = new ArrayList<>(Arrays.asList(p)); diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java b/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java index beff0e5b..7ab09283 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java @@ -91,7 +91,7 @@ public KASInfo getECPublicKey(Config.KASInfo kasInfo, NanoTDFType.ECCurve curve) @Override public Config.KASInfo getPublicKey(Config.KASInfo kasInfo) { - Config.KASInfo cachedValue = this.kasKeyCache.get(kasInfo.URL, kasInfo.Algorithm); + Config.KASInfo cachedValue = this.kasKeyCache.get(kasInfo.URL, kasInfo.Algorithm, kasInfo.KID); if (cachedValue != null) { return cachedValue; } diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KASKeyCache.java b/sdk/src/main/java/io/opentdf/platform/sdk/KASKeyCache.java index 5879dd05..75bae93c 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/KASKeyCache.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/KASKeyCache.java @@ -7,6 +7,7 @@ import java.time.temporal.ChronoUnit; import java.util.HashMap; import java.util.Map; +import java.util.Objects; /** * Class representing a cache for KAS (Key Access Server) information. @@ -24,14 +25,14 @@ public void clear() { this.cache = new HashMap<>(); } - public Config.KASInfo get(String url, String algorithm) { - log.debug("retrieving kasinfo for url = [{}], algorithm = [{}]", url, algorithm); - KASKeyRequest cacheKey = new KASKeyRequest(url, algorithm); + public Config.KASInfo get(String url, String algorithm, String kid) { + log.debug("retrieving kasinfo for url = [{}], algorithm = [{}], kid = [{}]", url, algorithm, kid); + KASKeyRequest cacheKey = new KASKeyRequest(url, algorithm, kid); LocalDateTime now = LocalDateTime.now(); TimeStampedKASInfo cachedValue = cache.get(cacheKey); if (cachedValue == null) { - log.debug("didn't find kasinfo for url = [{}], algorithm = [{}]", url, algorithm); + log.debug("didn't find kasinfo for key= [{}]", cacheKey); return null; } @@ -49,7 +50,7 @@ public Config.KASInfo get(String url, String algorithm) { public void store(Config.KASInfo kasInfo) { log.debug("storing kasInfo into the cache {}", kasInfo); - KASKeyRequest cacheKey = new KASKeyRequest(kasInfo.URL, kasInfo.Algorithm); + KASKeyRequest cacheKey = new KASKeyRequest(kasInfo.URL, kasInfo.Algorithm, kasInfo.KID); cache.put(cacheKey, new TimeStampedKASInfo(kasInfo, LocalDateTime.now())); } } @@ -85,30 +86,34 @@ public TimeStampedKASInfo(Config.KASInfo kasInfo, LocalDateTime timestamp) { class KASKeyRequest { private String url; private String algorithm; + private String kid; - public KASKeyRequest(String url, String algorithm) { + public KASKeyRequest(String url, String algorithm, String kid) { this.url = url; this.algorithm = algorithm; + this.kid = kid; } - // Override equals and hashCode to ensure proper functioning of the HashMap @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || !(o instanceof KASKeyRequest)) return false; + if (o == null || getClass() != o.getClass()) return false; KASKeyRequest that = (KASKeyRequest) o; - if (algorithm == null){ - return url.equals(that.url); - } - return url.equals(that.url) && algorithm.equals(that.algorithm); + return Objects.equals(url, that.url) && Objects.equals(algorithm, that.algorithm) && Objects.equals(kid, that.kid); } @Override public int hashCode() { - int result = 31 * url.hashCode(); - if (algorithm != null) { - result = result + algorithm.hashCode(); - } - return result; + return Objects.hash(url, algorithm, kid); + } + + @Override + public String toString() { + return "KASKeyRequest{" + + "url='" + url + '\'' + + ", algorithm='" + algorithm + '\'' + + ", kid='" + kid + '\'' + + '}'; } + + // Override equals and hashCode to ensure proper functioning of the HashMap } \ No newline at end of file diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java b/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java index 0f9cbd3d..3f7973a6 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java @@ -1,5 +1,8 @@ package io.opentdf.platform.sdk; +import io.opentdf.platform.policy.Algorithm; +import io.opentdf.platform.policy.KasPublicKeyAlgEnum; + import javax.annotation.Nonnull; import static io.opentdf.platform.sdk.NanoTDFType.ECCurve.SECP256R1; @@ -46,7 +49,43 @@ public static KeyType fromString(String keyType) { throw new IllegalArgumentException("No enum constant for key type: " + keyType); } + public static KeyType fromAlgorithm(Algorithm algorithm) { + if (algorithm == null) { + throw new IllegalArgumentException("Algorithm cannot be null"); + } + switch (algorithm) { + case ALGORITHM_RSA_2048: + return KeyType.RSA2048Key; + case ALGORITHM_EC_P256: + return KeyType.EC256Key; + case ALGORITHM_EC_P384: + return KeyType.EC384Key; + case ALGORITHM_EC_P521: + return KeyType.EC521Key; + default: + throw new IllegalArgumentException("Unsupported algorithm: " + algorithm); + } + } + + public static KeyType fromPublicKeyAlgorithm(KasPublicKeyAlgEnum algorithm) { + if (algorithm == null) { + throw new IllegalArgumentException("Algorithm cannot be null"); + } + switch (algorithm) { + case KAS_PUBLIC_KEY_ALG_ENUM_RSA_2048: + return KeyType.RSA2048Key; + case KAS_PUBLIC_KEY_ALG_ENUM_EC_SECP256R1: + return KeyType.EC256Key; + case KAS_PUBLIC_KEY_ALG_ENUM_EC_SECP384R1: + return KeyType.EC384Key; + case KAS_PUBLIC_KEY_ALG_ENUM_EC_SECP521R1: + return KeyType.EC521Key; + default: + throw new IllegalArgumentException("Unsupported algorithm: " + algorithm); + } + } + public boolean isEc() { return this.curve != null; } -} \ No newline at end of file +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/Planner.java b/sdk/src/main/java/io/opentdf/platform/sdk/Planner.java new file mode 100644 index 00000000..b07c735c --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/Planner.java @@ -0,0 +1,190 @@ +package io.opentdf.platform.sdk; + +import com.connectrpc.ConnectException; +import com.google.gson.Gson; +import com.google.gson.JsonSyntaxException; +import com.google.gson.annotations.SerializedName; +import io.opentdf.platform.policy.Algorithm; +import io.opentdf.platform.policy.SimpleKasKey; +import io.opentdf.platform.policy.SimpleKasPublicKey; +import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationRequest; +import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse; +import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.UUID; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + + +public class Planner { + private static final String BASE_KEY = "base_key"; + private final Config.TDFConfig tdfConfig; + private final SDK.Services services; + private final BiFunction granterFactory; + + + private static final Logger logger = LoggerFactory.getLogger(Planner.class); + + Planner(Config.TDFConfig config, SDK.Services services, BiFunction granterFactory) { + this.tdfConfig = Objects.requireNonNull(config); + this.services = Objects.requireNonNull(services); + this.granterFactory = granterFactory; + } + + private static String getUUID() { + return UUID.randomUUID().toString(); + } + + Map> getSplits() { + List splitPlan; + if (tdfConfig.autoconfigure) { + if (tdfConfig.splitPlan != null && !tdfConfig.splitPlan.isEmpty()) { + throw new IllegalArgumentException("cannot use autoconfigure with a split plan provided in the TDFConfig"); + } + splitPlan = getAutoconfigurePlan(services, tdfConfig); + } else if (tdfConfig.splitPlan == null || tdfConfig.splitPlan.isEmpty()) { + splitPlan = generatePlanFromProvidedKases(tdfConfig.kasInfoList); + } else { + splitPlan = tdfConfig.splitPlan.stream() + .map(k -> new Autoconfigure.KeySplitTemplate(k.kas, k.splitID, null, null)) + .collect(Collectors.toList()); + } + + if (splitPlan.isEmpty()) { + throw new SDK.KasInfoMissing("no plan was constructed via autoconfigure, explicit split plan or provided kases"); + } + return resolveKeys(splitPlan); + } + + private List getAutoconfigurePlan(SDK.Services services, Config.TDFConfig tdfConfig) { + Autoconfigure.Granter granter = granterFactory.apply(services, tdfConfig); + return granter.getSplits(defaultKases(tdfConfig), Planner::getUUID, () -> Planner.fetchBaseKey(services.wellknown())); + } + + + List generatePlanFromProvidedKases(List kases) { + if (kases.size() == 1) { + var kasInfo = kases.get(0); + return Collections.singletonList(new Autoconfigure.KeySplitTemplate(kasInfo.URL, "", kasInfo.KID, null)); + } + List splitPlan = new ArrayList<>(); + for (var kasInfo : kases) { + var keyType = kasInfo.Algorithm == null ? null : KeyType.fromString(kasInfo.Algorithm); + splitPlan.add(new Autoconfigure.KeySplitTemplate(kasInfo.URL, getUUID(), kasInfo.KID, keyType)); + } + return splitPlan; + } + + static Optional fetchBaseKey(WellKnownServiceClientInterface wellknown) { + var responseMessage = wellknown + .getWellKnownConfigurationBlocking(GetWellKnownConfigurationRequest.getDefaultInstance(), Collections.emptyMap()) + .execute(); + GetWellKnownConfigurationResponse response; + try { + response = RequestHelper.getOrThrow(responseMessage); + } catch (ConnectException e) { + throw new SDKException("unable to retrieve base key from well known endpoint", e); + } + + String baseKeyJson; + try { + baseKeyJson = response + .getConfiguration() + .getFieldsOrThrow(BASE_KEY) + .getStringValue(); + } catch (IllegalArgumentException e) { + logger.info( "no `" + BASE_KEY + "` found in well known configuration.", e); + return Optional.empty(); + } + + BaseKey baseKey; + try { + baseKey = gson.fromJson(baseKeyJson, BaseKey.class); + } catch (JsonSyntaxException e) { + throw new SDKException("base key in well known configuration is malformed [" + baseKeyJson + "]", e); + } + + if (baseKey == null || baseKey.kasUrl == null || baseKey.publicKey == null || baseKey.publicKey.kid == null || baseKey.publicKey.pem == null || baseKey.publicKey.algorithm == null) { + logger.error("base key in well known configuration is missing required fields [{}]. base key will not be used", baseKeyJson); + return Optional.empty(); + } + + return Optional.of(SimpleKasKey.newBuilder() + .setKasUri(baseKey.kasUrl) + .setPublicKey( + SimpleKasPublicKey.newBuilder() + .setKid(baseKey.publicKey.kid) + .setAlgorithm(baseKey.publicKey.algorithm) + .setPem(baseKey.publicKey.pem) + .build()) + .build()); + } + + private static final Gson gson = new Gson(); + + private static class BaseKey { + @SerializedName("kas_url") + String kasUrl; + + @SerializedName("public_key") + Key publicKey; + + private static class Key { + String kid; + String pem; + Algorithm algorithm; + } + } + + Map> resolveKeys(List splitPlan) { + Map> conjunction = new HashMap<>(); + var latestKASInfo = new HashMap(); + // Seed anything passed in manually + for (Config.KASInfo kasInfo : tdfConfig.kasInfoList) { + if (kasInfo.PublicKey != null && !kasInfo.PublicKey.isEmpty()) { + latestKASInfo.put(kasInfo.URL, kasInfo); + } + } + + for (var splitInfo: splitPlan) { + // Public key was passed in with kasInfoList + // TODO First look up in attribute information / add to split plan? + Config.KASInfo ki = latestKASInfo.get(splitInfo.kas); + if (ki == null || ki.PublicKey == null || ki.PublicKey.isBlank() || (splitInfo.kid != null && !splitInfo.kid.equals(ki.KID))) { + logger.info("no public key provided for KAS at {}, retrieving", splitInfo.kas); + var getKI = new Config.KASInfo(); + getKI.URL = splitInfo.kas; + getKI.Algorithm = splitInfo.keyType == null + ? (tdfConfig.wrappingKeyType == null ? null : tdfConfig.wrappingKeyType.toString()) + : splitInfo.keyType.toString(); + ki = services.kas().getPublicKey(getKI); + latestKASInfo.put(splitInfo.kas, ki); + } + conjunction.computeIfAbsent(splitInfo.splitID, s -> new ArrayList<>()).add(ki); + } + return conjunction; + } + + static List defaultKases(Config.TDFConfig config) { + List allk = new ArrayList<>(); + List defk = new ArrayList<>(); + + for (Config.KASInfo kasInfo : config.kasInfoList) { + if (kasInfo.Default != null && kasInfo.Default) { + defk.add(kasInfo.URL); + } else if (defk.isEmpty()) { + allk.add(kasInfo.URL); + } + } + return defk.isEmpty() ? allk : defk; + } +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java b/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java index 74468ecf..f6009ed6 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java @@ -9,6 +9,7 @@ import io.opentdf.platform.policy.namespaces.NamespaceServiceClientInterface; import io.opentdf.platform.policy.resourcemapping.ResourceMappingServiceClientInterface; import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceClientInterface; +import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface; import javax.net.ssl.TrustManager; import java.io.IOException; @@ -75,6 +76,8 @@ public interface Services extends AutoCloseable { KeyAccessServerRegistryServiceClientInterface kasRegistry(); + WellKnownServiceClientInterface wellknown(); + KAS kas(); } diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java b/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java index dd83f75a..03b1943c 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java @@ -33,6 +33,7 @@ import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationRequest; import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse; import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClient; +import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface; import nl.altindag.ssl.SSLFactory; import nl.altindag.ssl.pem.util.PemUtils; import okhttp3.OkHttpClient; @@ -251,6 +252,7 @@ ServicesAndInternals buildServices() { var resourceMappingService = new ResourceMappingServiceClient(client); var authorizationService = new AuthorizationServiceClient(client); var kasRegistryService = new KeyAccessServerRegistryServiceClient(client); + var wellKnownService = new WellKnownServiceClient(client); var services = new SDK.Services() { @Override @@ -290,6 +292,11 @@ public KeyAccessServerRegistryServiceClient kasRegistry() { return kasRegistryService; } + @Override + public WellKnownServiceClientInterface wellknown() { + return wellKnownService; + } + @Override public SDK.KAS kas() { return kasClient; diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java index 4b605c15..80522d38 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java @@ -5,11 +5,8 @@ import com.google.gson.GsonBuilder; import com.nimbusds.jose.*; -import io.opentdf.platform.policy.Value; import io.opentdf.platform.policy.kasregistry.ListKeyAccessServersRequest; import io.opentdf.platform.policy.kasregistry.ListKeyAccessServersResponse; -import io.opentdf.platform.sdk.Config.TDFConfig; -import io.opentdf.platform.sdk.Autoconfigure.AttributeValueFQN; import io.opentdf.platform.sdk.Config.KASInfo; import org.apache.commons.codec.DecoderException; @@ -141,7 +138,7 @@ private PolicyObject createPolicyObject(List at private static final Base64.Encoder encoder = Base64.getEncoder(); - private void prepareManifest(Config.TDFConfig tdfConfig, SDK.KAS kas) { + private void prepareManifest(Config.TDFConfig tdfConfig, Map> splits) { manifest.tdfVersion = tdfConfig.renderVersionInfoInManifest ? TDF_VERSION : null; manifest.encryptionInformation.keyAccessType = kSplitKeyType; manifest.encryptionInformation.keyAccessObj = new ArrayList<>(); @@ -149,60 +146,12 @@ private void prepareManifest(Config.TDFConfig tdfConfig, SDK.KAS kas) { PolicyObject policyObject = createPolicyObject(tdfConfig.attributes); String base64PolicyObject = encoder .encodeToString(gson.toJson(policyObject).getBytes(StandardCharsets.UTF_8)); - Map latestKASInfo = new HashMap<>(); - if (tdfConfig.splitPlan == null || tdfConfig.splitPlan.isEmpty()) { - // Default split plan: Split keys across all KASes - List splitPlan = new ArrayList<>(tdfConfig.kasInfoList.size()); - int i = 0; - for (Config.KASInfo kasInfo : tdfConfig.kasInfoList) { - Autoconfigure.KeySplitStep step = new Autoconfigure.KeySplitStep(kasInfo.URL, ""); - if (tdfConfig.kasInfoList.size() > 1) { - step.splitID = String.format("s-%d", i++); - } - splitPlan.add(step); - if (kasInfo.PublicKey != null && !kasInfo.PublicKey.isEmpty()) { - latestKASInfo.put(kasInfo.URL, kasInfo); - } - } - tdfConfig.splitPlan = splitPlan; - } - // Seed anything passed in manually - for (Config.KASInfo kasInfo : tdfConfig.kasInfoList) { - if (kasInfo.PublicKey != null && !kasInfo.PublicKey.isEmpty()) { - latestKASInfo.put(kasInfo.URL, kasInfo); - } - } - // split plan: restructure by conjunctions - Map> conjunction = new HashMap<>(); - List splitIDs = new ArrayList<>(); - - for (Autoconfigure.KeySplitStep splitInfo : tdfConfig.splitPlan) { - // Public key was passed in with kasInfoList - // TODO First look up in attribute information / add to split plan? - Config.KASInfo ki = latestKASInfo.get(splitInfo.kas); - if (ki == null || ki.PublicKey == null || ki.PublicKey.isBlank()) { - logger.info("no public key provided for KAS at {}, retrieving", splitInfo.kas); - var getKI = new Config.KASInfo(); - getKI.URL = splitInfo.kas; - getKI.Algorithm = tdfConfig.wrappingKeyType.toString(); - getKI = kas.getPublicKey(getKI); - latestKASInfo.put(splitInfo.kas, getKI); - ki = getKI; - } - if (conjunction.containsKey(splitInfo.splitID)) { - conjunction.get(splitInfo.splitID).add(ki); - } else { - List newList = new ArrayList<>(); - newList.add(ki); - conjunction.put(splitInfo.splitID, newList); - splitIDs.add(splitInfo.splitID); - } - } + List symKeys = new ArrayList<>(splits.size()); + for (var split : splits.entrySet()) { + String splitID = split.getKey(); - List symKeys = new ArrayList<>(splitIDs.size()); - for (String splitID : splitIDs) { // Symmetric key byte[] symKey = new byte[GCM_KEY_SIZE]; sRandom.nextBytes(symKey); @@ -229,7 +178,8 @@ private void prepareManifest(Config.TDFConfig tdfConfig, SDK.KAS kas) { encryptedMetadata = encoder.encodeToString(metadata.getBytes(StandardCharsets.UTF_8)); } - for (Config.KASInfo kasInfo : conjunction.get(splitID)) { + List kasInfos = split.getValue(); + for (Config.KASInfo kasInfo : kasInfos) { if (kasInfo.PublicKey == null || kasInfo.PublicKey.isEmpty()) { throw new SDK.KasPublicKeyMissing("Kas public key is missing in kas information list"); } @@ -263,7 +213,11 @@ private Manifest.KeyAccess createKeyAccess(Config.TDFConfig tdfConfig, Config.KA keyAccess.sid = splitID; keyAccess.schemaVersion = KEY_ACCESS_SCHEMA_VERSION; - if (tdfConfig.wrappingKeyType.isEc()) { + var algorithm = kasInfo.Algorithm == null || kasInfo.Algorithm.isEmpty() + ? tdfConfig.wrappingKeyType.toString() + : kasInfo.Algorithm; + + if (KeyType.fromString(algorithm).isEc()) { var ecKeyWrappedKeyInfo = createECWrappedKey(tdfConfig, kasInfo, symKey); keyAccess.wrappedKey = ecKeyWrappedKeyInfo.wrappedKey; keyAccess.ephemeralPublicKey = ecKeyWrappedKeyInfo.publicKey; @@ -300,7 +254,6 @@ private String createRSAWrappedKey(Config.KASInfo kasInfo, byte[] symKey) { } } - private static final Base64.Decoder decoder = Base64.getDecoder(); public static class Reader { @@ -325,7 +278,6 @@ public Manifest getManifest() { this.aesGcm = new AesGcm(payloadKey); this.payloadKey = payloadKey; this.unencryptedMetadata = unencryptedMetadata; - } public void readPayload(OutputStream outputStream) throws SDK.SegmentSignatureMismatch, IOException { @@ -399,35 +351,11 @@ private static byte[] calculateSignature(byte[] data, byte[] secret, Config.Inte } TDFObject createTDF(InputStream payload, OutputStream outputStream, Config.TDFConfig tdfConfig) throws SDKException, IOException { - - if (tdfConfig.autoconfigure) { - Autoconfigure.Granter granter = new Autoconfigure.Granter(new ArrayList<>()); - if (tdfConfig.attributeValues != null && !tdfConfig.attributeValues.isEmpty()) { - granter = Autoconfigure.newGranterFromAttributes(tdfConfig.attributeValues.toArray(new Value[0])); - } else if (tdfConfig.attributes != null && !tdfConfig.attributes.isEmpty()) { - granter = Autoconfigure.newGranterFromService(services.attributes(), services.kas().getKeyCache(), - tdfConfig.attributes.toArray(new AttributeValueFQN[0])); - } - - if (granter == null) { - throw new AutoConfigureException("Failed to create Granter"); // Replace with appropriate error handling - } - - List dk = defaultKases(tdfConfig); - tdfConfig.splitPlan = granter.plan(dk, () -> UUID.randomUUID().toString()); - - if (tdfConfig.splitPlan == null) { - throw new AutoConfigureException("Failed to generate Split Plan"); // Replace with appropriate error - // handling - } - } - - if (tdfConfig.kasInfoList.isEmpty() && (tdfConfig.splitPlan == null || tdfConfig.splitPlan.isEmpty())) { - throw new SDK.KasInfoMissing("kas information is missing, no key access template specified or inferred"); - } + Planner planner = new Planner(tdfConfig, services, Autoconfigure::createGranter); + Map> splits = planner.getSplits(); TDFObject tdfObject = new TDFObject(); - tdfObject.prepareManifest(tdfConfig, services.kas()); + tdfObject.prepareManifest(tdfConfig, splits); long encryptedSegmentSize = tdfConfig.defaultSegmentSize + kGcmIvSize + kAesBlockSize; TDFWriter tdfWriter = new TDFWriter(outputStream); @@ -560,22 +488,6 @@ TDFObject createTDF(InputStream payload, OutputStream outputStream, Config.TDFCo return tdfObject; } - static List defaultKases(TDFConfig config) { - List allk = new ArrayList<>(); - List defk = new ArrayList<>(); - - for (KASInfo kasInfo : config.kasInfoList) { - if (kasInfo.Default != null && kasInfo.Default) { - defk.add(kasInfo.URL); - } else if (defk.isEmpty()) { - allk.add(kasInfo.URL); - } - } - if (defk.isEmpty()) { - return allk; - } - return defk; - } Reader loadTDF(SeekableByteChannel tdf, String platformUrl) throws SDKException, IOException { return loadTDF(tdf, Config.newTDFReaderConfig(), platformUrl); diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/AutoconfigureTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/AutoconfigureTest.java index 59cd0912..53fe1c06 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/AutoconfigureTest.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/AutoconfigureTest.java @@ -1,18 +1,8 @@ package io.opentdf.platform.sdk; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - import com.connectrpc.ResponseMessage; import com.connectrpc.UnaryBlockingCall; +import io.opentdf.platform.policy.Algorithm; import io.opentdf.platform.policy.Attribute; import io.opentdf.platform.policy.AttributeRuleTypeEnum; import io.opentdf.platform.policy.KasPublicKey; @@ -21,16 +11,17 @@ import io.opentdf.platform.policy.KeyAccessServer; import io.opentdf.platform.policy.Namespace; import io.opentdf.platform.policy.PublicKey; +import io.opentdf.platform.policy.SimpleKasKey; +import io.opentdf.platform.policy.SimpleKasPublicKey; import io.opentdf.platform.policy.Value; import io.opentdf.platform.policy.attributes.AttributesServiceClient; import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsRequest; import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsResponse; import io.opentdf.platform.sdk.Autoconfigure.AttributeValueFQN; +import io.opentdf.platform.sdk.Autoconfigure.Granter; import io.opentdf.platform.sdk.Autoconfigure.Granter.AttributeBooleanExpression; import io.opentdf.platform.sdk.Autoconfigure.Granter.BooleanKeyExpression; import io.opentdf.platform.sdk.Autoconfigure.KeySplitStep; -import io.opentdf.platform.sdk.Autoconfigure.Granter; - import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.mockito.Mockito; @@ -39,10 +30,27 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Objects; +import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; +import java.util.function.Supplier; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class AutoconfigureTest { @@ -56,6 +64,16 @@ public class AutoconfigureTest { public static final String SPECIFIED_KAS = "https://attr.kas.com/"; public static final String EVEN_MORE_SPECIFIC_KAS = "https://value.kas.com/"; private static final String NAMESPACE_KAS = "https://namespace.kas.com/"; + private static final SimpleKasKey NAMESPACE_KAS_KEY = SimpleKasKey.newBuilder().setKasUri("https://mapped.example.com").setKasId("mapped").setPublicKey( + SimpleKasPublicKey.newBuilder().setAlgorithm(Algorithm.ALGORITHM_EC_P521).setPem("namespacekey").setKid("namespacekeykid").build() + ).build(); + private static final SimpleKasKey ATTRIBUTE_KEY = SimpleKasKey.newBuilder().setKasUri("https://mapped.example.com").setKasId("mapped").setPublicKey( + SimpleKasPublicKey.newBuilder().setAlgorithm(Algorithm.ALGORITHM_EC_P521).setPem("attrpem").setKid("attrkeykid").build() + ).build(); + private static final SimpleKasKey VALUE_KEY = SimpleKasKey.newBuilder().setKasUri("https://mapped.example.com").setKasId("mapped").setPublicKey( + SimpleKasPublicKey.newBuilder().setAlgorithm(Algorithm.ALGORITHM_EC_P521).setPem("valuepem").setKid("valuekeykid").build() + ).build(); + private static Autoconfigure.AttributeNameFQN UNMAPPED; private static Autoconfigure.AttributeNameFQN SPKSPECKED; private static Autoconfigure.AttributeNameFQN SPKUNSPECKED; @@ -64,6 +82,8 @@ public class AutoconfigureTest { private static Autoconfigure.AttributeNameFQN REL; private static Autoconfigure.AttributeNameFQN UNSPECKED; private static Autoconfigure.AttributeNameFQN SPECKED; + private static Autoconfigure.AttributeNameFQN MAPPED; + private static Autoconfigure.AttributeNameFQN SPKMAPPED; private static Autoconfigure.AttributeValueFQN clsA; private static Autoconfigure.AttributeValueFQN clsS; @@ -84,6 +104,9 @@ public class AutoconfigureTest { private static Autoconfigure.AttributeValueFQN spk2spk2uns; private static Autoconfigure.AttributeValueFQN spk2spk2spk; + private static Autoconfigure.AttributeValueFQN mp2uns2uns; + private static Autoconfigure.AttributeValueFQN mp2uns2mp; + @BeforeAll public static void setup() throws AutoConfigureException { // Initialize the FQNs (Fully Qualified Names) @@ -92,8 +115,11 @@ public static void setup() throws AutoConfigureException { REL = new Autoconfigure.AttributeNameFQN("https://virtru.com/attr/Releasable%20To"); UNSPECKED = new Autoconfigure.AttributeNameFQN("https://other.com/attr/unspecified"); SPECKED = new Autoconfigure.AttributeNameFQN("https://other.com/attr/specified"); + MAPPED = new Autoconfigure.AttributeNameFQN("https://other.com/attr/mapped"); + UNMAPPED = new Autoconfigure.AttributeNameFQN("https://mapped.com/attr/unspecified"); SPKUNSPECKED = new Autoconfigure.AttributeNameFQN("https://hasgrants.com/attr/unspecified"); SPKSPECKED = new Autoconfigure.AttributeNameFQN("https://hasgrants.com/attr/specified"); + SPKMAPPED = new Autoconfigure.AttributeNameFQN("https://hasgrants.com/attr/mapped"); clsA = new Autoconfigure.AttributeValueFQN("https://virtru.com/attr/Classification/value/Allowed"); clsS = new Autoconfigure.AttributeValueFQN("https://virtru.com/attr/Classification/value/Secret"); @@ -117,6 +143,9 @@ public static void setup() throws AutoConfigureException { spk2uns2spk = new Autoconfigure.AttributeValueFQN("https://hasgrants.com/attr/unspecified/value/specked"); spk2spk2uns = new Autoconfigure.AttributeValueFQN("https://hasgrants.com/attr/specified/value/unspecked"); spk2spk2spk = new Autoconfigure.AttributeValueFQN("https://hasgrants.com/attr/specified/value/specked"); + + mp2uns2uns = new Autoconfigure.AttributeValueFQN("https://mapped.com/attr/unspecified/value/unspecked"); + mp2uns2mp = new Autoconfigure.AttributeValueFQN("https://mapped.com/attr/unspecified/value/mapped"); } private static String spongeCase(String s) { @@ -181,6 +210,7 @@ private Attribute mockAttributeFor(Autoconfigure.AttributeNameFQN fqn) { Namespace ns1 = Namespace.newBuilder().setId("v").setName("virtru.com").setFqn("https://virtru.com").build(); Namespace ns2 = Namespace.newBuilder().setId("o").setName("other.com").setFqn("https://other.com").build(); Namespace ns3 = Namespace.newBuilder().setId("h").setName("hasgrants.com").addGrants(KeyAccessServer.newBuilder().setUri(NAMESPACE_KAS).build()).setFqn("https://hasgrants.com").build(); + Namespace ns4 = Namespace.newBuilder().setId("m").setName("mapped.com").addKasKeys(NAMESPACE_KAS_KEY).build(); String key = fqn.getKey(); if (key.equals(CLS.getKey())) { @@ -196,10 +226,15 @@ private Attribute mockAttributeFor(Autoconfigure.AttributeNameFQN fqn) { .setName("Releasable To").setRule(AttributeRuleTypeEnum.ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF) .setFqn(fqn.toString()).build(); } else if (key.equals(SPECKED.getKey())) { + var kasKey = SimpleKasKey.newBuilder() + .setKasUri(SPECKED.getKey()) + .setPublicKey(SimpleKasPublicKey.newBuilder().setPem("speckedpem") + .setAlgorithm(Algorithm.ALGORITHM_EC_P521) + .setKid("speckedkeykid")).build(); return Attribute.newBuilder().setId("SPK").setNamespace(ns2) .setName("specified").setRule(AttributeRuleTypeEnum.ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF) .setFqn(fqn.toString()) - .addGrants(KeyAccessServer.newBuilder().setUri(SPECIFIED_KAS).build()) + .addGrants(KeyAccessServer.newBuilder().setUri(SPECIFIED_KAS).addKasKeys(kasKey)) .build(); } else if (key.equals(UNSPECKED.getKey())) { return Attribute.newBuilder().setId("UNS").setNamespace(ns2) @@ -216,6 +251,17 @@ private Attribute mockAttributeFor(Autoconfigure.AttributeNameFQN fqn) { .setName("unspecified").setRule(AttributeRuleTypeEnum.ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF) .setName(fqn.toString()) .build(); + } else if (key.equals(MAPPED.getKey())) { + return Attribute.newBuilder().setId(MAPPED.getKey()).setNamespace(ns4) + .setName("mapped attribute").setRule(AttributeRuleTypeEnum.ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF) + .setKasKeys(0, ATTRIBUTE_KEY) + .setName(fqn.toString()) + .build(); + } else if (key.equals(UNMAPPED.getKey())) { + return Attribute.newBuilder().setId(UNMAPPED.getKey()).setNamespace(ns4) + .setName("unmapped attribute").setRule(AttributeRuleTypeEnum.ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF) + .setName(fqn.toString()) + .build(); } throw new IllegalArgumentException("Key not recognized: " + key); @@ -287,7 +333,13 @@ private Value mockValueFor(Autoconfigure.AttributeValueFQN fqn) throws AutoConfi p = p.toBuilder().addGrants(KeyAccessServer.newBuilder().setUri(EVEN_MORE_SPECIFIC_KAS).build()) .build(); } + } else if (Objects.equals(UNMAPPED.getKey(), an.getKey())) { + if (fqn.value().equalsIgnoreCase("mapped")) { + p = p.toBuilder().addKasKeys(VALUE_KEY) + .build(); + } } + return p; } @@ -368,7 +420,7 @@ public void testConfigurationServicePutGet() { for (ConfigurationTestCase tc : testCases) { assertDoesNotThrow(() -> { List v = valuesToPolicy(tc.getPolicy().toArray(new AttributeValueFQN[0])); - Granter grants = Autoconfigure.newGranterFromAttributes(v.toArray(new Value[0])); + Granter grants = Autoconfigure.newGranterFromAttributes(null, v.toArray(new Value[0])); assertThat(grants).isNotNull(); assertThat(grants.getGrants()).hasSize(tc.getSize()); assertThat(policyToStringKeys(tc.getPolicy())).containsAll(grants.getGrants().keySet()); @@ -396,7 +448,7 @@ public void testReasonerConstructAttributeBoolean() { "https://virtru.com/attr/Classification/value/Secret&https://virtru.com/attr/Releasable%20To/value/CAN", "[DEFAULT]&(https://kas.ca/)", "(https://kas.ca/)", - List.of(new KeySplitStep(KAS_CA, ""))), + List.of(new Autoconfigure.KeySplitTemplate(KAS_CA, "", null, null))), new ReasonerTestCase( "one defaulted attr", List.of(clsS), @@ -404,7 +456,7 @@ public void testReasonerConstructAttributeBoolean() { "https://virtru.com/attr/Classification/value/Secret", "[DEFAULT]", "", - List.of(new KeySplitStep(KAS_US, ""))), + List.of(new Autoconfigure.KeySplitTemplate(KAS_US, "", null, null))), new ReasonerTestCase( "empty policy", List.of(), @@ -412,7 +464,7 @@ public void testReasonerConstructAttributeBoolean() { "∅", "", "", - List.of(new KeySplitStep(KAS_US, ""))), + List.of(new Autoconfigure.KeySplitTemplate(KAS_US, "", null, null))), new ReasonerTestCase( "old school splits", List.of(), @@ -420,8 +472,9 @@ public void testReasonerConstructAttributeBoolean() { "∅", "", "", - List.of(new KeySplitStep(KAS_AU, "1"), new KeySplitStep(KAS_CA, "2"), - new KeySplitStep(KAS_US, "3"))), + List.of(new Autoconfigure.KeySplitTemplate(KAS_AU, "1", null, null), + new Autoconfigure.KeySplitTemplate(KAS_CA, "2", null, null), + new Autoconfigure.KeySplitTemplate(KAS_US, "3", null, null))), new ReasonerTestCase( "simple with all three ops", List.of(clsS, rel2gbr, n2kInt), @@ -429,7 +482,7 @@ public void testReasonerConstructAttributeBoolean() { "https://virtru.com/attr/Classification/value/Secret&https://virtru.com/attr/Releasable%20To/value/GBR&https://virtru.com/attr/Need%20to%20Know/value/INT", "[DEFAULT]&(https://kas.uk/)&(https://kas.uk/)", "(https://kas.uk/)", - List.of(new KeySplitStep(KAS_UK, ""))), + List.of(new Autoconfigure.KeySplitTemplate(KAS_UK, "", null, null))), new ReasonerTestCase( "compartments", List.of(clsS, rel2gbr, rel2usa, n2kHCS, n2kSI), @@ -437,8 +490,10 @@ public void testReasonerConstructAttributeBoolean() { "https://virtru.com/attr/Classification/value/Secret&https://virtru.com/attr/Releasable%20To/value/{GBR,USA}&https://virtru.com/attr/Need%20to%20Know/value/{HCS,SI}", "[DEFAULT]&(https://kas.uk/⋁https://kas.us/)&(https://hcs.kas.us/⋀https://si.kas.us/)", "(https://kas.uk/⋁https://kas.us/)&(https://hcs.kas.us/)&(https://si.kas.us/)", - List.of(new KeySplitStep(KAS_UK, "1"), new KeySplitStep(KAS_US, "1"), - new KeySplitStep(KAS_US_HCS, "2"), new KeySplitStep(KAS_US_SA, "3"))), + List.of(new Autoconfigure.KeySplitTemplate(KAS_UK, "1", null, null), + new Autoconfigure.KeySplitTemplate(KAS_US, "1", null, null), + new Autoconfigure.KeySplitTemplate(KAS_US_HCS, "2", null, null), + new Autoconfigure.KeySplitTemplate(KAS_US_SA, "3", null, null))), new ReasonerTestCase( "compartments - case insensitive", List.of( @@ -447,11 +502,13 @@ public void testReasonerConstructAttributeBoolean() { "https://virtru.com/attr/Classification/value/Secret&https://virtru.com/attr/Releasable%20To/value/{GBR,USA}&https://virtru.com/attr/Need%20to%20Know/value/{HCS,SI}", "[DEFAULT]&(https://kas.uk/⋁https://kas.us/)&(https://hcs.kas.us/⋀https://si.kas.us/)", "(https://kas.uk/⋁https://kas.us/)&(https://hcs.kas.us/)&(https://si.kas.us/)", - List.of(new KeySplitStep(KAS_UK, "1"), new KeySplitStep(KAS_US, "1"), - new KeySplitStep(KAS_US_HCS, "2"), new KeySplitStep(KAS_US_SA, "3")))); + List.of(new Autoconfigure.KeySplitTemplate(KAS_UK, "1", null, null), + new Autoconfigure.KeySplitTemplate(KAS_US, "1", null, null), + new Autoconfigure.KeySplitTemplate(KAS_US_HCS, "2", null, null), + new Autoconfigure.KeySplitTemplate(KAS_US_SA, "3", null, null)))); for (ReasonerTestCase tc : testCases) { - Granter reasoner = Autoconfigure.newGranterFromAttributes( + Granter reasoner = Autoconfigure.newGranterFromAttributes(null, valuesToPolicy(tc.getPolicy().toArray(new AttributeValueFQN[0])).toArray(new Value[0])); assertThat(reasoner).isNotNull(); @@ -467,15 +524,33 @@ public void testReasonerConstructAttributeBoolean() { var wrapper = new Object() { int i = 0; }; - List plan = reasoner.plan(tc.getDefaults(), () -> { - return String.valueOf(wrapper.i++ + 1); - } - - ); - assertThat(plan).isEqualTo(tc.getPlan()); + List plan = reasoner.getSplits(tc.getDefaults(), () -> String.valueOf(wrapper.i++ + 1), Optional::empty); + assertThat(plan) + .as(tc.name) + .isEqualTo(tc.getPlan()); } } + @Test + void testUsingAttributeMappedAtNamespace() { + Granter granter = Autoconfigure.newGranterFromAttributes(new KASKeyCache(), mockValueFor(mp2uns2uns)); + var counter = new AtomicInteger(0); + var splitPlan = granter.getSplits(Collections.emptyList(), () -> Integer.toString(counter.getAndIncrement()), Optional::empty); + assertThat(splitPlan).isEqualTo(List.of(new Autoconfigure.KeySplitTemplate("https://mapped.example.com", "", NAMESPACE_KAS_KEY.getPublicKey().getKid(), KeyType.EC521Key))); + } + + @Test + void testUsingAttributeMappedAtMultiplePlaces() { + var attributes = new Value[]{mockValueFor(mp2uns2uns), mockValueFor(mp2uns2mp)}; + Granter granter = Autoconfigure.newGranterFromAttributes(new KASKeyCache(), attributes); + var counter = new AtomicInteger(0); + var splitPlan = granter.getSplits(Collections.emptyList(), () -> Integer.toString(counter.getAndIncrement()), Optional::empty); + assertThat(splitPlan).isEqualTo(List.of( + new Autoconfigure.KeySplitTemplate(NAMESPACE_KAS_KEY.getKasUri(), "0", NAMESPACE_KAS_KEY.getPublicKey().getKid(), KeyType.EC521Key), + new Autoconfigure.KeySplitTemplate(VALUE_KEY.getKasUri(), "0", VALUE_KEY.getPublicKey().getKid(), KeyType.EC521Key) + )); + } + GetAttributeValuesByFqnsResponse getResponse(GetAttributeValuesByFqnsRequest req) { GetAttributeValuesByFqnsResponse.Builder builder = GetAttributeValuesByFqnsResponse.newBuilder(); @@ -505,72 +580,77 @@ public void testReasonerSpecificity() { "uns.uns => default", List.of(uns2uns), List.of(KAS_US), - List.of(new KeySplitStep(KAS_US, ""))), + List.of(new Autoconfigure.KeySplitTemplate(KAS_US, "", null, null))), new ReasonerTestCase( "uns.spk => spk", List.of(uns2spk), List.of(KAS_US), - List.of(new KeySplitStep(EVEN_MORE_SPECIFIC_KAS, ""))), + List.of(new Autoconfigure.KeySplitTemplate(EVEN_MORE_SPECIFIC_KAS, "", null, null))), new ReasonerTestCase( "spk.uns => spk", List.of(spk2uns), List.of(KAS_US), - List.of(new KeySplitStep(SPECIFIED_KAS, ""))), + List.of(new Autoconfigure.KeySplitTemplate(SPECIFIED_KAS, "", null, null))), new ReasonerTestCase( "spk.spk => value.spk", List.of(spk2spk), List.of(KAS_US), - List.of(new KeySplitStep(EVEN_MORE_SPECIFIC_KAS, ""))), + List.of(new Autoconfigure.KeySplitTemplate(EVEN_MORE_SPECIFIC_KAS, "", null, null))), new ReasonerTestCase( "spk.spk & spk.uns => value.spk || attr.spk", List.of(spk2spk, spk2uns), List.of(KAS_US), - List.of(new KeySplitStep(EVEN_MORE_SPECIFIC_KAS, "1"), new KeySplitStep(SPECIFIED_KAS, "1"))), + List.of(new Autoconfigure.KeySplitTemplate(EVEN_MORE_SPECIFIC_KAS, "1", null, null), + new Autoconfigure.KeySplitTemplate(SPECIFIED_KAS, "1", null, null))), new ReasonerTestCase( "spk.uns & spk.spk => value.spk || attr.spk", List.of(spk2uns, spk2spk), List.of(KAS_US), - List.of(new KeySplitStep(SPECIFIED_KAS, "1"), new KeySplitStep(EVEN_MORE_SPECIFIC_KAS, "1"))), + List.of(new Autoconfigure.KeySplitTemplate(SPECIFIED_KAS, "1", null, null), + new Autoconfigure.KeySplitTemplate(EVEN_MORE_SPECIFIC_KAS, "1", null, null))), new ReasonerTestCase( "uns.spk & spk.spk => value.spk", List.of(spk2spk, uns2spk), List.of(KAS_US), - List.of(new KeySplitStep(EVEN_MORE_SPECIFIC_KAS, ""))), + List.of(new Autoconfigure.KeySplitTemplate(EVEN_MORE_SPECIFIC_KAS, "", null, null))), new ReasonerTestCase( "uns.spk & uns.uns => spk", List.of(uns2spk, uns2uns), List.of(KAS_US), - List.of(new KeySplitStep(EVEN_MORE_SPECIFIC_KAS, ""))), + List.of(new Autoconfigure.KeySplitTemplate(EVEN_MORE_SPECIFIC_KAS, "", null, null))), new ReasonerTestCase( "uns.uns & uns.spk => spk", List.of(uns2uns, uns2spk), List.of(KAS_US), - List.of(new KeySplitStep(EVEN_MORE_SPECIFIC_KAS, ""))), + List.of(new Autoconfigure.KeySplitTemplate(EVEN_MORE_SPECIFIC_KAS, "", null, null))), new ReasonerTestCase( - "uns.uns & uns.spk => spk", + "uns.uns & spk.spk => spk", List.of(uns2uns, spk2spk), List.of(KAS_US), - List.of(new KeySplitStep(EVEN_MORE_SPECIFIC_KAS, ""))), + List.of(new Autoconfigure.KeySplitTemplate(EVEN_MORE_SPECIFIC_KAS, "", null, null))), new ReasonerTestCase( "spk.uns.uns => ns.spk", List.of(spk2uns2uns, uns2uns), List.of(KAS_US), - List.of(new KeySplitStep(NAMESPACE_KAS, ""))), + List.of(new Autoconfigure.KeySplitTemplate(NAMESPACE_KAS, "", null, null))), new ReasonerTestCase( "spk.uns.uns & uns.uns => ns.spk", List.of(spk2uns2uns, uns2uns), List.of(KAS_US), - List.of(new KeySplitStep(NAMESPACE_KAS, ""))), + List.of(new Autoconfigure.KeySplitTemplate(NAMESPACE_KAS, "", null, null))), new ReasonerTestCase( "spk.uns.uns & uns.spk => ns.spk && spk", List.of(spk2uns2uns, uns2spk), List.of(KAS_US), - List.of(new KeySplitStep(NAMESPACE_KAS, "1"), new KeySplitStep(EVEN_MORE_SPECIFIC_KAS, "2"))), + List.of(new Autoconfigure.KeySplitTemplate(NAMESPACE_KAS, "1", null, null), + new Autoconfigure.KeySplitTemplate(EVEN_MORE_SPECIFIC_KAS, "2", null, null))), new ReasonerTestCase( "spk.uns.uns & spk.spk.uns && spk.uns.spk => ns.spk || attr.spk || value.spk", List.of(spk2uns2uns, spk2spk2uns, spk2uns2spk), List.of(KAS_US), - List.of(new KeySplitStep(NAMESPACE_KAS, "1"), new KeySplitStep(EVEN_MORE_SPECIFIC_KAS, "1"), new KeySplitStep(SPECIFIED_KAS, "2"))) + List.of(new Autoconfigure.KeySplitTemplate(NAMESPACE_KAS, "1", null, null), + new Autoconfigure.KeySplitTemplate(EVEN_MORE_SPECIFIC_KAS, "1", null, null), + new Autoconfigure.KeySplitTemplate(SPECIFIED_KAS, "2", null, null))) ); for (ReasonerTestCase tc : testCases) { @@ -598,12 +678,11 @@ public void cancel() { var wrapper = new Object() { int i = 0; }; - List plan = reasoner.plan(tc.getDefaults(), () -> { - return String.valueOf(wrapper.i++ + 1); - } + List plan = reasoner.getSplits(tc.getDefaults(), () -> String.valueOf(wrapper.i++ + 1), Optional::empty); + assertThat(plan) + .as(tc.name) + .hasSameElementsAs(tc.getPlan()); - ); - assertThat(plan).hasSameElementsAs(tc.getPlan()); } } @@ -688,10 +767,10 @@ private static class ReasonerTestCase { private final String ats; private final String keyed; private final String reduced; - private final List plan; + private final List plan; ReasonerTestCase(String name, List policy, List defaults, String ats, String keyed, - String reduced, List plan) { + String reduced, List plan) { this.name = name; this.policy = policy; this.defaults = defaults; @@ -701,7 +780,7 @@ private static class ReasonerTestCase { this.plan = plan; } - ReasonerTestCase(String name, List policy, List defaults, List plan) { + ReasonerTestCase(String name, List policy, List defaults, List plan) { this.name = name; this.policy = policy; this.defaults = defaults; @@ -735,7 +814,7 @@ public String getReduced() { return reduced; } - public List getPlan() { + public List getPlan() { return plan; } } @@ -744,13 +823,11 @@ public List getPlan() { void testStoreKeysToCache_NoKeys() { KASKeyCache keyCache = Mockito.mock(KASKeyCache.class); KeyAccessServer kas1 = KeyAccessServer.newBuilder().setPublicKey( - PublicKey.newBuilder().setCached( - KasPublicKeySet.newBuilder())) + PublicKey.newBuilder().setCached( + KasPublicKeySet.newBuilder())) .build(); - List kases = List.of(kas1); - - Autoconfigure.storeKeysToCache(kases, keyCache); + Autoconfigure.storeKeysToCache(List.of(kas1), Collections.emptyList(), keyCache); verify(keyCache, never()).store(any(Config.KASInfo.class)); } @@ -780,14 +857,11 @@ void testStoreKeysToCache_WithKeys() { .setUri("https://example.com/kas") .build(); - // Add the KeyAccessServer to a list - List kases = List.of(kas1); - // Call the method under test - Autoconfigure.storeKeysToCache(kases, keyCache); + Autoconfigure.storeKeysToCache(List.of(kas1), Collections.emptyList(), keyCache); // Verify that the key was stored in the cache - Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1"); + Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1", "test-kid"); assertNotNull(storedKASInfo); assertEquals("https://example.com/kas", storedKASInfo.URL); assertEquals("test-kid", storedKASInfo.KID); @@ -826,21 +900,18 @@ void testStoreKeysToCache_MultipleKasEntries() { .setUri("https://example.com/kas") .build(); - // Add the KeyAccessServer to a list - List kases = List.of(kas1); - // Call the method under test - Autoconfigure.storeKeysToCache(kases, keyCache); + Autoconfigure.storeKeysToCache(List.of(kas1), Collections.emptyList(), keyCache); // Verify that the key was stored in the cache - Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1"); + Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1", "test-kid"); assertNotNull(storedKASInfo); assertEquals("https://example.com/kas", storedKASInfo.URL); assertEquals("test-kid", storedKASInfo.KID); assertEquals("ec:secp256r1", storedKASInfo.Algorithm); assertEquals("public-key-pem", storedKASInfo.PublicKey); - Config.KASInfo storedKASInfo2 = keyCache.get("https://example.com/kas", "rsa:2048"); + Config.KASInfo storedKASInfo2 = keyCache.get("https://example.com/kas", "rsa:2048", "test-kid-2"); assertNotNull(storedKASInfo2); assertEquals("https://example.com/kas", storedKASInfo2.URL); assertEquals("test-kid-2", storedKASInfo2.KID); @@ -849,7 +920,7 @@ void testStoreKeysToCache_MultipleKasEntries() { } GetAttributeValuesByFqnsResponse getResponseWithGrants(GetAttributeValuesByFqnsRequest req, - List grants) { + List grants) { GetAttributeValuesByFqnsResponse.Builder builder = GetAttributeValuesByFqnsResponse.newBuilder(); for (String v : req.getFqnsList()) { @@ -901,13 +972,15 @@ void testKeyCacheFromGrants() { AttributesServiceClient attributesServiceClient = mock(AttributesServiceClient.class); when(attributesServiceClient.getAttributeValuesByFqnsBlocking(any(), any())).thenAnswer(invocation -> { - var request = (GetAttributeValuesByFqnsRequest)invocation.getArgument(0); - return new UnaryBlockingCall(){ + var request = (GetAttributeValuesByFqnsRequest) invocation.getArgument(0); + return new UnaryBlockingCall() { @Override public ResponseMessage execute() { return new ResponseMessage.Success<>(getResponseWithGrants(request, List.of(kas1)), Collections.emptyMap(), Collections.emptyMap()); } - @Override public void cancel() { + + @Override + public void cancel() { // not really calling anything } }; @@ -920,14 +993,14 @@ public ResponseMessage execute() { assertThat(reasoner).isNotNull(); // Verify that the key was stored in the cache - Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1"); + Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1", "test-kid"); assertNotNull(storedKASInfo); assertEquals("https://example.com/kas", storedKASInfo.URL); assertEquals("test-kid", storedKASInfo.KID); assertEquals("ec:secp256r1", storedKASInfo.Algorithm); assertEquals("public-key-pem", storedKASInfo.PublicKey); - Config.KASInfo storedKASInfo2 = keyCache.get("https://example.com/kas", "rsa:2048"); + Config.KASInfo storedKASInfo2 = keyCache.get("https://example.com/kas", "rsa:2048", "test-kid-2"); assertNotNull(storedKASInfo2); assertEquals("https://example.com/kas", storedKASInfo2.URL); assertEquals("test-kid-2", storedKASInfo2.KID); @@ -935,4 +1008,171 @@ public ResponseMessage execute() { assertEquals("public-key-pem-2", storedKASInfo2.PublicKey); } + @Test + void testUsingBaseKeyWhenNoMappedKeysOrGrants() { + Autoconfigure.Granter granter = Autoconfigure.newGranterFromAttributes(null); + SimpleKasKey key = SimpleKasKey.newBuilder() + .setKasUri("https://example.com/kas") + .setPublicKey( + SimpleKasPublicKey.newBuilder() + .setKid("thenewkid") + .setPem("anotherpem") + .setAlgorithm(Algorithm.ALGORITHM_EC_P521) + ).build(); + + var splits = granter.getSplits( + List.of("https://example.org/kas2"), + () -> { + throw new IllegalStateException("the plan should have a single element"); + }, + () -> Optional.of(key)); + assertThat(splits).hasSize(1); + assertThat(splits.get(0)).isEqualTo(new Autoconfigure.KeySplitTemplate("https://example.com/kas", "", "thenewkid", KeyType.EC521Key)); + } + + @Test + void testUsingDefaultKasesWhenNothingElseProvided() { + Autoconfigure.Granter granter = Autoconfigure.newGranterFromAttributes(null); + var counter = new AtomicInteger(); + Supplier splitGen = () -> String.valueOf(counter.getAndIncrement()); + var splits = granter.getSplits( + List.of("https://example.org/kas1", "https://example.org/kas2"), + splitGen, + Optional::empty); + + assertThat(splits) + .hasSize(2) + .asList().containsExactly( + new Autoconfigure.KeySplitTemplate("https://example.org/kas1", "0", null, null), + new Autoconfigure.KeySplitTemplate("https://example.org/kas2", "1", null, null) + ); + } + + @Test + void createsGranterFromAttributeValues() { + // Arrange + Config.TDFConfig config = new Config.TDFConfig(); + config.attributeValues = List.of(mockValueFor(spk2spk), mockValueFor(rel2gbr)); + + SDK.Services services = mock(SDK.Services.class); + SDK.KAS kas = mock(SDK.KAS.class); + when(services.kas()).thenReturn(kas); + when(services.attributes()).thenThrow(new IllegalStateException("should never use the attribute service when attributes are provided")); + when(kas.getKeyCache()).thenReturn(null); // No cache needed for this test + + // Act + Autoconfigure.Granter granter = Autoconfigure.createGranter(services, config); + + // Assert + assertThat(granter).isNotNull(); + assertThat(granter.getPolicy()).hasSize(2); + assertThat(granter.getPolicy()).containsExactlyInAnyOrder( + new AttributeValueFQN("https://other.com/attr/specified/value/specked"), + new AttributeValueFQN("https://virtru.com/attr/Releasable%20To/value/GBR") + ); + } + + @Test + void createsGranterFromService() { + // Arrange + SDK.Services services = mock(SDK.Services.class); + SDK.KAS kas = mock(SDK.KAS.class); + AttributesServiceClient attributesServiceClient = mock(AttributesServiceClient.class); + + // Prepare a request and a mocked response + List policy = List.of( + new AttributeValueFQN("https://other.com/attr/specified/value/specked"), + new AttributeValueFQN("https://virtru.com/attr/Releasable%20To/value/GBR") + ); + + when(services.kas()).thenReturn(kas); + when(services.attributes()).thenReturn(attributesServiceClient); + + // Mock the attribute service to return a response with the expected values + when(attributesServiceClient.getAttributeValuesByFqnsBlocking(any(), any())).thenAnswer(invocation -> { + GetAttributeValuesByFqnsResponse.Builder builder = GetAttributeValuesByFqnsResponse.newBuilder(); + for (AttributeValueFQN fqn : policy) { + Value value = Value.newBuilder() + .setId(fqn.toString()) + .setFqn(fqn.toString()) + .build(); + builder.putFqnAttributeValues(fqn.toString(), + GetAttributeValuesByFqnsResponse.AttributeAndValue.newBuilder() + .setValue(value) + .build()); + } + return TestUtil.successfulUnaryCall(builder.build()); + }); + + // Act + Autoconfigure.Granter granter = Autoconfigure.createGranter(services, new Config.TDFConfig() {{ + attributeValues = null; // force use of service + attributes = policy; + }}); + + // Assert + assertThat(granter).isNotNull(); + // The policy should be empty because attributeValues is null, but the test ensures the service is called + // If you want to check the service call, verify it: + verify(services).attributes(); + } + + @Test + void getSplits_usesAutoconfigurePlan_whenAutoconfigureTrue() { + var tdfConfig = new Config.TDFConfig(); + tdfConfig.autoconfigure = true; + tdfConfig.kasInfoList = new ArrayList<>(); + tdfConfig.splitPlan = null; + + var kas = Mockito.mock(SDK.KAS.class); + Mockito.when(kas.getKeyCache()).thenReturn(new KASKeyCache()); + Config.KASInfo kasInfo = new Config.KASInfo() {{ + URL = "https://kas.example.com"; + Algorithm = "ec:secp256r1"; + KID = "kid"; + }}; + Mockito.when(kas.getPublicKey(any())).thenReturn(kasInfo); + + var services = new FakeServicesBuilder().setKas(kas).build(); + + // Mock granterFactory to return a granter with a known split plan + var expectedSplit = new Autoconfigure.KeySplitTemplate("https://kas.example.com", "", "kid", null); + var granter = Mockito.mock(Autoconfigure.Granter.class); + Mockito.when(granter.getSplits( + Mockito.anyList(), + Mockito.any(), + Mockito.any())) + .thenReturn(List.of(expectedSplit)); + + BiFunction granterFactory = + (s, c) -> granter; + + var planner = new Planner(tdfConfig, services, granterFactory); + + // Act + var splits = planner.getSplits(); + + // Assert + assertThat(splits).containsKey(""); + assertThat(splits.get("")).hasSize(1); + assertThat(splits.get("").get(0).URL).isEqualTo("https://kas.example.com"); + assertThat(splits.get("").get(0).KID).isEqualTo("kid"); + assertThat(splits.get("").get(0).Algorithm).isEqualTo("ec:secp256r1"); + } + + @Test + void testInvalidConfigurations() { + var config = new Config.TDFConfig(); + config.autoconfigure = true; + config.splitPlan = List.of(new KeySplitStep("kas1", "")); + Planner planner = new Planner(config, new FakeServicesBuilder().build(), (a, b) -> { throw new IllegalStateException("no way"); }); + Exception thrown = assertThrows(IllegalArgumentException.class, () -> planner.getSplits()); + assertThat(thrown.getMessage()).contains("cannot use autoconfigure with a split plan provided in the TDFConfig"); + + + config = new Config.TDFConfig() {{ autoconfigure = false; kasInfoList = Collections.EMPTY_LIST; splitPlan = null; }}; + var otherPlanner = new Planner(config, new FakeServicesBuilder().build(), (a, b) -> { throw new IllegalStateException("no way"); }); + thrown = assertThrows(SDK.KasInfoMissing.class, () -> otherPlanner.getSplits()); + assertThat(thrown.getMessage()).contains("no plan was constructed via autoconfigure, explicit split plan or provided kases"); + } } diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/FakeServices.java b/sdk/src/test/java/io/opentdf/platform/sdk/FakeServices.java index b3573593..2851b22b 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/FakeServices.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/FakeServices.java @@ -1,31 +1,34 @@ package io.opentdf.platform.sdk; -import io.opentdf.platform.authorization.AuthorizationServiceClient; -import io.opentdf.platform.policy.attributes.AttributesServiceClient; -import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClient; -import io.opentdf.platform.policy.namespaces.NamespaceServiceClient; -import io.opentdf.platform.policy.resourcemapping.ResourceMappingServiceClient; -import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceClient; +import io.opentdf.platform.authorization.AuthorizationServiceClientInterface; +import io.opentdf.platform.policy.attributes.AttributesServiceClientInterface; +import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClientInterface; +import io.opentdf.platform.policy.namespaces.NamespaceServiceClientInterface; +import io.opentdf.platform.policy.resourcemapping.ResourceMappingServiceClientInterface; +import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceClientInterface; +import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface; import java.util.Objects; public class FakeServices implements SDK.Services { - private final AuthorizationServiceClient authorizationService; - private final AttributesServiceClient attributesService; - private final NamespaceServiceClient namespaceService; - private final SubjectMappingServiceClient subjectMappingService; - private final ResourceMappingServiceClient resourceMappingService; - private final KeyAccessServerRegistryServiceClient keyAccessServerRegistryServiceFutureStub; + private final AuthorizationServiceClientInterface authorizationService; + private final AttributesServiceClientInterface attributesService; + private final NamespaceServiceClientInterface namespaceService; + private final SubjectMappingServiceClientInterface subjectMappingService; + private final ResourceMappingServiceClientInterface resourceMappingService; + private final KeyAccessServerRegistryServiceClientInterface keyAccessServerRegistryServiceFutureStub; + private final WellKnownServiceClientInterface wellKnownService; private final SDK.KAS kas; public FakeServices( - AuthorizationServiceClient authorizationService, - AttributesServiceClient attributesService, - NamespaceServiceClient namespaceService, - SubjectMappingServiceClient subjectMappingService, - ResourceMappingServiceClient resourceMappingService, - KeyAccessServerRegistryServiceClient keyAccessServerRegistryServiceFutureStub, + AuthorizationServiceClientInterface authorizationService, + AttributesServiceClientInterface attributesService, + NamespaceServiceClientInterface namespaceService, + SubjectMappingServiceClientInterface subjectMappingService, + ResourceMappingServiceClientInterface resourceMappingService, + KeyAccessServerRegistryServiceClientInterface keyAccessServerRegistryServiceFutureStub, + WellKnownServiceClientInterface wellKnownServiceClient, SDK.KAS kas) { this.authorizationService = authorizationService; this.attributesService = attributesService; @@ -33,39 +36,45 @@ public FakeServices( this.subjectMappingService = subjectMappingService; this.resourceMappingService = resourceMappingService; this.keyAccessServerRegistryServiceFutureStub = keyAccessServerRegistryServiceFutureStub; + this.wellKnownService = wellKnownServiceClient; this.kas = kas; } @Override - public AuthorizationServiceClient authorization() { + public AuthorizationServiceClientInterface authorization() { return Objects.requireNonNull(authorizationService); } @Override - public AttributesServiceClient attributes() { + public AttributesServiceClientInterface attributes() { return Objects.requireNonNull(attributesService); } @Override - public NamespaceServiceClient namespaces() { + public NamespaceServiceClientInterface namespaces() { return Objects.requireNonNull(namespaceService); } @Override - public SubjectMappingServiceClient subjectMappings() { + public SubjectMappingServiceClientInterface subjectMappings() { return Objects.requireNonNull(subjectMappingService); } @Override - public ResourceMappingServiceClient resourceMappings() { + public ResourceMappingServiceClientInterface resourceMappings() { return Objects.requireNonNull(resourceMappingService); } @Override - public KeyAccessServerRegistryServiceClient kasRegistry() { + public KeyAccessServerRegistryServiceClientInterface kasRegistry() { return Objects.requireNonNull(keyAccessServerRegistryServiceFutureStub); } + @Override + public WellKnownServiceClientInterface wellknown() { + return Objects.requireNonNull(wellKnownService); + } + @Override public SDK.KAS kas() { return Objects.requireNonNull(kas); diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/FakeServicesBuilder.java b/sdk/src/test/java/io/opentdf/platform/sdk/FakeServicesBuilder.java index 2a80f53d..558aee3b 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/FakeServicesBuilder.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/FakeServicesBuilder.java @@ -1,47 +1,54 @@ package io.opentdf.platform.sdk; -import io.opentdf.platform.authorization.AuthorizationServiceClient; -import io.opentdf.platform.policy.attributes.AttributesServiceClient; -import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClient; -import io.opentdf.platform.policy.namespaces.NamespaceServiceClient; -import io.opentdf.platform.policy.resourcemapping.ResourceMappingServiceClient; -import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceClient; +import io.opentdf.platform.authorization.AuthorizationServiceClientInterface; +import io.opentdf.platform.policy.attributes.AttributesServiceClientInterface; +import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClientInterface; +import io.opentdf.platform.policy.namespaces.NamespaceServiceClientInterface; +import io.opentdf.platform.policy.resourcemapping.ResourceMappingServiceClientInterface; +import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceClientInterface; +import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface; public class FakeServicesBuilder { - private AuthorizationServiceClient authorizationService; - private AttributesServiceClient attributesService; - private NamespaceServiceClient namespaceService; - private SubjectMappingServiceClient subjectMappingService; - private ResourceMappingServiceClient resourceMappingService; - private KeyAccessServerRegistryServiceClient keyAccessServerRegistryServiceFutureStub; + private AuthorizationServiceClientInterface authorizationService; + private AttributesServiceClientInterface attributesService; + private NamespaceServiceClientInterface namespaceService; + private SubjectMappingServiceClientInterface subjectMappingService; + private ResourceMappingServiceClientInterface resourceMappingService; + private KeyAccessServerRegistryServiceClientInterface keyAccessServerRegistryServiceFutureStub; + private WellKnownServiceClientInterface wellKnownServiceClient; private SDK.KAS kas; - public FakeServicesBuilder setAuthorizationService(AuthorizationServiceClient authorizationService) { + public FakeServicesBuilder setAuthorizationService(AuthorizationServiceClientInterface authorizationService) { this.authorizationService = authorizationService; return this; } - public FakeServicesBuilder setAttributesService(AttributesServiceClient attributesService) { + public FakeServicesBuilder setAttributesService(AttributesServiceClientInterface attributesService) { this.attributesService = attributesService; return this; } - public FakeServicesBuilder setNamespaceService(NamespaceServiceClient namespaceService) { + public FakeServicesBuilder setNamespaceService(NamespaceServiceClientInterface namespaceService) { this.namespaceService = namespaceService; return this; } - public FakeServicesBuilder setSubjectMappingService(SubjectMappingServiceClient subjectMappingService) { + public FakeServicesBuilder setSubjectMappingService(SubjectMappingServiceClientInterface subjectMappingService) { this.subjectMappingService = subjectMappingService; return this; } - public FakeServicesBuilder setResourceMappingService(ResourceMappingServiceClient resourceMappingService) { + public FakeServicesBuilder setResourceMappingService(ResourceMappingServiceClientInterface resourceMappingService) { this.resourceMappingService = resourceMappingService; return this; } - public FakeServicesBuilder setKeyAccessServerRegistryService(KeyAccessServerRegistryServiceClient keyAccessServerRegistryServiceFutureStub) { + public FakeServicesBuilder setWellknownService(WellKnownServiceClientInterface wellKnownServiceClient) { + this.wellKnownServiceClient = wellKnownServiceClient; + return this; + } + + public FakeServicesBuilder setKeyAccessServerRegistryService(KeyAccessServerRegistryServiceClientInterface keyAccessServerRegistryServiceFutureStub) { this.keyAccessServerRegistryServiceFutureStub = keyAccessServerRegistryServiceFutureStub; return this; } @@ -52,6 +59,7 @@ public FakeServicesBuilder setKas(SDK.KAS kas) { } public FakeServices build() { - return new FakeServices(authorizationService, attributesService, namespaceService, subjectMappingService, resourceMappingService, keyAccessServerRegistryServiceFutureStub, kas); + return new FakeServices(authorizationService, attributesService, namespaceService, subjectMappingService, + resourceMappingService, keyAccessServerRegistryServiceFutureStub, wellKnownServiceClient, kas); } } \ No newline at end of file diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/KASKeyCacheTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/KASKeyCacheTest.java index 5550678a..fdee682e 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/KASKeyCacheTest.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/KASKeyCacheTest.java @@ -35,7 +35,7 @@ void testStoreAndGet_WithinTimeLimit() { kasKeyCache.store(kasInfo1); // Retrieve the item within the time limit - Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048"); + Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048", "kid1"); // Ensure the item was correctly retrieved assertNotNull(result); @@ -51,12 +51,24 @@ void testStoreAndGet_AfterTimeLimit() { kasKeyCache.store(kasInfo1); // Simulate time passing by modifying the timestamp directly - KASKeyRequest cacheKey = new KASKeyRequest("https://example.com/kas1", "rsa:2048"); + KASKeyRequest cacheKey = new KASKeyRequest("https://example.com/kas1", "rsa:2048", "kid1"); TimeStampedKASInfo timeStampedKASInfo = new TimeStampedKASInfo(kasInfo1, LocalDateTime.now().minus(6, ChronoUnit.MINUTES)); kasKeyCache.cache.put(cacheKey, timeStampedKASInfo); // Attempt to retrieve the item after the time limit - Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048"); + Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048", "kid1"); + + // Ensure the item was not retrieved (it should have expired) + assertNull(result); + } + + @Test + void testStoreAndGet_DifferentKIDs() { + // Store an item in the cache + kasKeyCache.store(kasInfo1); + + // Attempt to retrieve the item with a different KID + Config.KASInfo result = kasKeyCache.get(kasInfo1.URL, kasInfo1.Algorithm, kasInfo1.KID + "different"); // Ensure the item was not retrieved (it should have expired) assertNull(result); @@ -72,7 +84,7 @@ void testStoreAndGet_WithNullAlgorithm() { kasKeyCache.store(kasInfo1); // Retrieve the item with a null algorithm - Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", null); + Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", null, "kid1"); // Ensure the item was correctly retrieved assertNotNull(result); @@ -91,7 +103,7 @@ void testClearCache() { kasKeyCache.clear(); // Attempt to retrieve the item after clearing the cache - Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048"); + Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048", "kid1"); // Ensure the item was not retrieved (the cache should be empty) assertNull(result); @@ -104,8 +116,8 @@ void testStoreMultipleItemsAndGet() { kasKeyCache.store(kasInfo2); // Retrieve each item and ensure they were correctly stored and retrieved - Config.KASInfo result1 = kasKeyCache.get("https://example.com/kas1", "rsa:2048"); - Config.KASInfo result2 = kasKeyCache.get("https://example.com/kas2", "ec:secp256r1"); + Config.KASInfo result1 = kasKeyCache.get("https://example.com/kas1", "rsa:2048", "kid1"); + Config.KASInfo result2 = kasKeyCache.get("https://example.com/kas2", "ec:secp256r1", "kid2"); assertNotNull(result1); assertEquals("https://example.com/kas1", result1.URL); @@ -119,8 +131,8 @@ void testStoreMultipleItemsAndGet() { @Test void testEqualsAndHashCode() { // Create two identical KASKeyRequest objects - KASKeyRequest keyRequest1 = new KASKeyRequest("https://example.com/kas1", "rsa:2048"); - KASKeyRequest keyRequest2 = new KASKeyRequest("https://example.com/kas1", "rsa:2048"); + KASKeyRequest keyRequest1 = new KASKeyRequest("https://example.com/kas1", "rsa:2048", "kid1"); + KASKeyRequest keyRequest2 = new KASKeyRequest("https://example.com/kas1", "rsa:2048", "kid1"); // Ensure that equals and hashCode work as expected assertEquals(keyRequest1, keyRequest2); diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/KeyTypeTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/KeyTypeTest.java new file mode 100644 index 00000000..6dc72f47 --- /dev/null +++ b/sdk/src/test/java/io/opentdf/platform/sdk/KeyTypeTest.java @@ -0,0 +1,44 @@ +package io.opentdf.platform.sdk; + +import org.junit.jupiter.api.Test; + +import static io.opentdf.platform.policy.Algorithm.ALGORITHM_EC_P256; +import static io.opentdf.platform.policy.Algorithm.ALGORITHM_EC_P384; +import static io.opentdf.platform.policy.Algorithm.ALGORITHM_EC_P521; +import static io.opentdf.platform.policy.Algorithm.ALGORITHM_RSA_2048; +import static io.opentdf.platform.policy.KasPublicKeyAlgEnum.KAS_PUBLIC_KEY_ALG_ENUM_EC_SECP256R1; +import static io.opentdf.platform.policy.KasPublicKeyAlgEnum.KAS_PUBLIC_KEY_ALG_ENUM_EC_SECP384R1; +import static io.opentdf.platform.policy.KasPublicKeyAlgEnum.KAS_PUBLIC_KEY_ALG_ENUM_EC_SECP521R1; +import static io.opentdf.platform.policy.KasPublicKeyAlgEnum.KAS_PUBLIC_KEY_ALG_ENUM_RSA_2048; +import static org.junit.jupiter.api.Assertions.*; + +class KeyTypeTest { + @Test + void testFromString() { + assertEquals(KeyType.RSA2048Key, KeyType.fromString("rsa:2048")); + assertEquals(KeyType.EC256Key, KeyType.fromString("ec:secp256r1")); + assertEquals(KeyType.EC384Key, KeyType.fromString("ec:secp384r1")); + assertEquals(KeyType.EC521Key, KeyType.fromString("ec:secp521r1")); + } + + @Test + void testFromStringInvalid() { + assertThrows(IllegalArgumentException.class, () -> KeyType.fromString("invalid:key")); + } + + @Test + void testFromAlgorithm() { + assertEquals(KeyType.RSA2048Key, KeyType.fromAlgorithm(ALGORITHM_RSA_2048)); + assertEquals(KeyType.EC256Key, KeyType.fromAlgorithm(ALGORITHM_EC_P256)); + assertEquals(KeyType.EC384Key, KeyType.fromAlgorithm(ALGORITHM_EC_P384)); + assertEquals(KeyType.EC521Key, KeyType.fromAlgorithm(ALGORITHM_EC_P521)); + } + + @Test + void testFromPublicKeyAlgEnum() { + assertEquals(KeyType.RSA2048Key, KeyType.fromPublicKeyAlgorithm(KAS_PUBLIC_KEY_ALG_ENUM_RSA_2048)); + assertEquals(KeyType.EC256Key, KeyType.fromPublicKeyAlgorithm(KAS_PUBLIC_KEY_ALG_ENUM_EC_SECP256R1)); + assertEquals(KeyType.EC384Key, KeyType.fromPublicKeyAlgorithm(KAS_PUBLIC_KEY_ALG_ENUM_EC_SECP384R1)); + assertEquals(KeyType.EC521Key, KeyType.fromPublicKeyAlgorithm(KAS_PUBLIC_KEY_ALG_ENUM_EC_SECP521R1)); + } +} \ No newline at end of file diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/PlannerTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/PlannerTest.java new file mode 100644 index 00000000..c7316dd6 --- /dev/null +++ b/sdk/src/test/java/io/opentdf/platform/sdk/PlannerTest.java @@ -0,0 +1,271 @@ +package io.opentdf.platform.sdk; + +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import io.opentdf.platform.policy.Algorithm; +import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse; +import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +class PlannerTest { + + @Test + void fetchBaseKey() { + var wellknownService = Mockito.mock(WellKnownServiceClientInterface.class); + var baseKeyJson = "{\"kas_url\":\"https://example.com/base_key\",\"public_key\":{\"algorithm\":\"ALGORITHM_RSA_2048\",\"kid\":\"thekid\",\"pem\": \"thepem\"}}"; + var val = Value.newBuilder().setStringValue(baseKeyJson).build(); + var config = Struct.newBuilder().putFields("base_key", val).build(); + var response = GetWellKnownConfigurationResponse + .newBuilder() + .setConfiguration(config) + .build(); + + Mockito.when(wellknownService.getWellKnownConfigurationBlocking(Mockito.any(), Mockito.anyMap())) + .thenReturn(TestUtil.successfulUnaryCall(response)); + + + var baseKey = Planner.fetchBaseKey(wellknownService); + assertThat(baseKey).isNotEmpty(); + var simpleKasKey = baseKey.get(); + assertThat(simpleKasKey.getKasUri()).isEqualTo("https://example.com/base_key"); + assertThat(simpleKasKey.getPublicKey().getAlgorithm()).isEqualTo(Algorithm.ALGORITHM_RSA_2048); + assertThat(simpleKasKey.getPublicKey().getKid()).isEqualTo("thekid"); + assertThat(simpleKasKey.getPublicKey().getPem()).isEqualTo("thepem"); + } + + @Test + void fetchBaseKeyWithNoBaseKey() { + var wellknownService = Mockito.mock(WellKnownServiceClientInterface.class); + var response = GetWellKnownConfigurationResponse + .newBuilder() + .setConfiguration(Struct.newBuilder().build()) + .build(); + + Mockito.when(wellknownService.getWellKnownConfigurationBlocking(Mockito.any(), Mockito.anyMap())) + .thenReturn(TestUtil.successfulUnaryCall(response)); + + var baseKey = Planner.fetchBaseKey(wellknownService); + assertThat(baseKey).isEmpty(); + } + + @Test + void fetchBaseKeyWithMissingFields() { + var wellknownService = Mockito.mock(WellKnownServiceClientInterface.class); + // Missing 'kid', 'pem', and 'algorithm' in public_key + var baseKeyJson = "{\"kas_url\":\"https://example.com/base_key\",\"public_key\":{}}"; + var val = Value.newBuilder().setStringValue(baseKeyJson).build(); + var config = Struct.newBuilder().putFields("base_key", val).build(); + var response = GetWellKnownConfigurationResponse + .newBuilder() + .setConfiguration(config) + .build(); + + Mockito.when(wellknownService.getWellKnownConfigurationBlocking(Mockito.any(), Mockito.anyMap())) + .thenReturn(TestUtil.successfulUnaryCall(response)); + + var baseKey = Planner.fetchBaseKey(wellknownService); + assertThat(baseKey).isEmpty(); + } + + @Test + void generatePlanFromProvidedKases() { + var kas1 = new Config.KASInfo(); + kas1.URL = "https://kas1.example.com"; + kas1.KID = "kid1"; + + var kas2 = new Config.KASInfo(); + kas2.URL = "https://kas2.example.com"; + kas2.KID = "kid2"; + kas2.Algorithm = "ec:secp256r1"; + + var tdfConfig = new Config.TDFConfig(); + tdfConfig.kasInfoList.add(kas1); + tdfConfig.kasInfoList.add(kas2); + + var planner = new Planner(tdfConfig, new FakeServicesBuilder().build(), (ignore1, ignored2) -> { throw new IllegalArgumentException("no granter needed"); }); + List splitPlan = planner.generatePlanFromProvidedKases(tdfConfig.kasInfoList); + + assertThat(splitPlan).asList().hasSize(2); + assertThat(splitPlan.get(0).kas).isEqualTo("https://kas1.example.com"); + assertThat(splitPlan.get(0).kid).isEqualTo("kid1"); + assertThat(splitPlan.get(0).keyType).isNull(); + + assertThat(splitPlan.get(1).kas).isEqualTo("https://kas2.example.com"); + assertThat(splitPlan.get(1).kid).isEqualTo("kid2"); + assertThat(splitPlan.get(1).keyType).isEqualTo(KeyType.EC256Key); + + assertThat(splitPlan.get(0).splitID).isNotEqualTo(splitPlan.get(1).splitID); + } + + @Test + void testFillingInKeysWithAutoConfigure() { + var kas = Mockito.mock(SDK.KAS.class); + Mockito.when(kas.getPublicKey(Mockito.any())).thenAnswer(invocation -> { + Config.KASInfo kasInfo = invocation.getArgument(0, Config.KASInfo.class); + var ret = new Config.KASInfo(); + ret.URL = kasInfo.URL; + if (Objects.equals(kasInfo.URL, "https://kas1.example.com")) { + ret.PublicKey = "pem1"; + ret.Algorithm = "rsa:2048"; + ret.KID = "kid1"; + } else if (Objects.equals(kasInfo.URL, "https://kas2.example.com")) { + ret.PublicKey = "pem2"; + ret.Algorithm = "ec:secp256r1"; + ret.KID = "kid2"; + } else if (Objects.equals(kasInfo.URL, "https://kas3.example.com")) { + ret.PublicKey = "pem3"; + ret.Algorithm = "ec:secp384r1"; + ret.KID = "kid3"; + assertThat(kasInfo.Algorithm).isEqualTo("ec:secp384r1"); + } else { + throw new IllegalArgumentException("Unexpected KAS URL: " + kasInfo.URL); + } + return ret; + }); + var tdfConfig = new Config.TDFConfig(); + tdfConfig.autoconfigure = true; + tdfConfig.wrappingKeyType = KeyType.RSA2048Key; + tdfConfig.kasInfoList = List.of( + new Config.KASInfo() {{ + URL = "https://kas4.example.com"; + KID = "kid4"; + Algorithm = "ec:secp384r1"; + PublicKey = "pem4"; + }} + ); + var planner = new Planner(tdfConfig, new FakeServicesBuilder().setKas(kas).build(), (ignore1, ignored2) -> { throw new IllegalArgumentException("no granter needed"); }); + var plan = List.of( + new Autoconfigure.KeySplitTemplate("https://kas1.example.com", "split1", null, null), + new Autoconfigure.KeySplitTemplate("https://kas4.example.com", "split1", "kid4", null), + new Autoconfigure.KeySplitTemplate("https://kas2.example.com", "split2", "kid2", null), + new Autoconfigure.KeySplitTemplate("https://kas3.example.com", "split2", null, KeyType.EC384Key) + ); + Map> filledInPlan = planner.resolveKeys(plan); + assertThat(filledInPlan.keySet().stream().collect(Collectors.toList())).asList().containsExactlyInAnyOrder("split1", "split2"); + assertThat(filledInPlan.get("split1")).asList().hasSize(2); + var kasInfo1 = filledInPlan.get("split1").stream().filter(k -> "kid1".equals(k.KID)).findFirst().get(); + assertThat(kasInfo1.URL).isEqualTo("https://kas1.example.com"); + assertThat(kasInfo1.Algorithm).isEqualTo("rsa:2048"); + assertThat(kasInfo1.PublicKey).isEqualTo("pem1"); + var kasInfo4 = filledInPlan.get("split1").stream().filter(k -> "kid4".equals(k.KID)).findFirst().get(); + assertThat(kasInfo4.URL).isEqualTo("https://kas4.example.com"); + assertThat(kasInfo4.Algorithm).isEqualTo("ec:secp384r1"); + assertThat(kasInfo4.PublicKey).isEqualTo("pem4"); + + assertThat(filledInPlan.get("split2")).asList().hasSize(2); + var kasInfo2 = filledInPlan.get("split2").stream().filter(kasInfo -> "kid2".equals(kasInfo.KID)).findFirst().get(); + assertThat(kasInfo2.URL).isEqualTo("https://kas2.example.com"); + assertThat(kasInfo2.Algorithm).isEqualTo("ec:secp256r1"); + assertThat(kasInfo2.PublicKey).isEqualTo("pem2"); + var kasInfo3 = filledInPlan.get("split2").stream().filter(kasInfo -> "kid3".equals(kasInfo.KID)).findFirst().get(); + assertThat(kasInfo3.URL).isEqualTo("https://kas3.example.com"); + assertThat(kasInfo3.Algorithm).isEqualTo("ec:secp384r1"); + assertThat(kasInfo3.PublicKey).isEqualTo("pem3"); + } + + @Test + void returnsOnlyDefaultKasesIfPresent() { + var kas1 = new Config.KASInfo(); + kas1.URL = "https://kas1.example.com"; + kas1.Default = true; + + var kas2 = new Config.KASInfo(); + kas2.URL = "https://kas2.example.com"; + kas2.Default = false; + + var kas3 = new Config.KASInfo(); + kas3.URL = "https://kas3.example.com"; + kas3.Default = true; + + var config = new Config.TDFConfig(); + config.kasInfoList.addAll(List.of(kas1, kas2, kas3)); + + List result = Planner.defaultKases(config); + + Assertions.assertThat(result).containsExactlyInAnyOrder("https://kas1.example.com", "https://kas3.example.com"); + } + + @Test + void returnsAllKasesIfNoDefault() { + var kas1 = new Config.KASInfo(); + kas1.URL = "https://kas1.example.com"; + kas1.Default = false; + + var kas2 = new Config.KASInfo(); + kas2.URL = "https://kas2.example.com"; + kas2.Default = null; // not set + + var config = new Config.TDFConfig(); + config.kasInfoList.addAll(List.of(kas1, kas2)); + + List result = Planner.defaultKases(config); + Assertions.assertThat(result).containsExactlyInAnyOrder("https://kas1.example.com", "https://kas2.example.com"); + } + + @Test + void returnsEmptyListIfNoKases() { + var config = new Config.TDFConfig(); + List result = Planner.defaultKases(config); + Assertions.assertThat(result).isEmpty(); + } + + @Test + void usesProvidedSplitPlanWhenNotAutoconfigure() { + var kas = Mockito.mock(SDK.KAS.class); + Mockito.when(kas.getPublicKey(Mockito.any())).thenAnswer(invocation -> { + Config.KASInfo kasInfo = invocation.getArgument(0, Config.KASInfo.class); + var ret = new Config.KASInfo(); + ret.URL = kasInfo.URL; + if (Objects.equals(kasInfo.URL, "https://kas1.example.com")) { + ret.PublicKey = "pem1"; + ret.Algorithm = "rsa:2048"; + ret.KID = "kid1"; + } else if (Objects.equals(kasInfo.URL, "https://kas2.example.com")) { + ret.PublicKey = "pem2"; + ret.Algorithm = "ec:secp256r1"; + ret.KID = "kid2"; + } else { + throw new IllegalArgumentException("Unexpected KAS URL: " + kasInfo.URL); + } + return ret; + }); + // Arrange + var kas1 = new Config.KASInfo(); + kas1.URL = "https://kas1.example.com"; + kas1.KID = "kid1"; + kas1.Algorithm = "rsa:2048"; + + var kas2 = new Config.KASInfo(); + kas2.URL = "https://kas2.example.com"; + kas2.KID = "kid2"; + kas2.Algorithm = "ec:secp256"; + + var splitStep1 = new Autoconfigure.KeySplitStep(kas1.URL, "split1"); + var splitStep2 = new Autoconfigure.KeySplitStep(kas2.URL, "split2"); + + var tdfConfig = new Config.TDFConfig(); + tdfConfig.autoconfigure = false; + tdfConfig.kasInfoList.add(kas1); + tdfConfig.kasInfoList.add(kas2); + tdfConfig.splitPlan = List.of(splitStep1, splitStep2); + + var planner = new Planner(tdfConfig, new FakeServicesBuilder().setKas(kas).build(), (ignore1, ignored2) -> { throw new IllegalArgumentException("no granter needed"); }); + + // Act + Map> splits = planner.getSplits(); + + // Assert + Assertions.assertThat(splits).hasSize(2); + Assertions.assertThat(splits.get("split1")).extracting("URL").containsExactly("https://kas1.example.com"); + Assertions.assertThat(splits.get("split2")).extracting("URL").containsExactly("https://kas2.example.com"); + } +} \ No newline at end of file diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/TestUtil.java b/sdk/src/test/java/io/opentdf/platform/sdk/TestUtil.java new file mode 100644 index 00000000..53076007 --- /dev/null +++ b/sdk/src/test/java/io/opentdf/platform/sdk/TestUtil.java @@ -0,0 +1,22 @@ +package io.opentdf.platform.sdk; + +import com.connectrpc.ResponseMessage; +import com.connectrpc.UnaryBlockingCall; + +import java.util.Collections; + +public class TestUtil { + static UnaryBlockingCall successfulUnaryCall(T result) { + return new UnaryBlockingCall() { + @Override + public ResponseMessage execute() { + return new ResponseMessage.Success<>(result, Collections.emptyMap(), Collections.emptyMap()); + } + + @Override + public void cancel() { + // in tests we don't need to preserve server resources, so no-op + } + }; + } +} diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/ZipReaderTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/ZipReaderTest.java index a66e0f84..087743db 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/ZipReaderTest.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/ZipReaderTest.java @@ -146,5 +146,7 @@ public void testReadingAndWritingRandomFiles() throws IOException { entry.getData().transferTo(zipData); assertThat(zipData.toByteArray()).isEqualTo(namesToData.get(entry.getName())); } + + assertThat(reader.getEntries().size()).isEqualTo(namesToData.size()); } } \ No newline at end of file