diff --git a/lib/src/concurrency.ts b/lib/src/concurrency.ts new file mode 100644 index 00000000..a461e0ca --- /dev/null +++ b/lib/src/concurrency.ts @@ -0,0 +1,96 @@ +type LabelledSuccess = { lid: string; value: Promise }; +type LabelledFailure = { lid: string; e: any }; + +async function labelPromise(label: string, promise: Promise): Promise> { + try { + const value = await promise; + return { lid: label, value: Promise.resolve(value) }; + } catch (e) { + throw { lid: label, e }; + } +} + +// Pooled variant of Promise.all; implements most of the logic of the real all, +// but with a pool size of n. Rejects on first reject, or returns a list +// of all successful responses. Operates with at most n 'active' promises at a time. +// For tracking purposes, all promises must have a unique identifier. +export async function allPool(n: number, p: Record>): Promise[]> { + const pool: Record>> = {}; + const resolved: Awaited[] = []; + for (const [id, job] of Object.entries(p)) { + // while the size of jobs to do is greater than n, + // let n jobs run and take the first one to finish out of the pool + pool[id] = labelPromise(id, job); + if (Object.keys(pool).length > n - 1) { + const promises = Object.values(pool); + try { + const { lid, value } = await Promise.race(promises); + resolved.push(await value); + console.log(`succeeded on promise ${lid}`, value); + delete pool[lid]; + } catch (err) { + const { lid, e } = err as LabelledFailure; + console.warn(`failed on promise ${lid}`, err); + throw e; + } + } + } + try { + for (const labelled of await Promise.all(Object.values(pool))) { + console.log(`real.all succeeded on promise ${labelled.lid}`, labelled); + resolved.push(await labelled.value); + } + } catch (err) { + if ('lid' in err && 'e' in err) { + throw err.e; + } else { + throw err; + } + } + return resolved; +} + +// Pooled variant of promise.any; implements most of the logic of the real any, +// but with a pool size of n, and returns the first successful promise, +// operating with at most n 'active' promises at a time. +export async function anyPool(n: number, p: Record>): Promise> { + const pool: Record>> = {}; + const rejections = []; + for (const [id, job] of Object.entries(p)) { + // while the size of jobs to do is greater than n, + // let n jobs run and take the first one to finish out of the pool + pool[id] = labelPromise(id, job); + if (Object.keys(pool).length > n - 1) { + const promises = Object.values(pool); + try { + const { lid, value } = await Promise.race(promises); + console.log(`any succeeded on promise ${lid}`, value); + return await value; + } catch (error) { + const { lid, e } = error; + rejections.push(e); + delete pool[lid]; + console.log(`any failed on promise ${lid}`, e); + } + } + } + try { + const { lid, value } = await Promise.any(Object.values(pool)); + console.log(`real.any succeeded on promise ${lid}`); + return await value; + } catch (errors) { + console.log(`real.any failed`, errors); + if (errors instanceof AggregateError) { + for (const error of errors.errors) { + if ('lid' in error && 'e' in error) { + rejections.push(error.e); + } else { + rejections.push(error); + } + } + } else { + rejections.push(errors); + } + } + throw new AggregateError(rejections); +} diff --git a/lib/tdf3/src/client/builders.ts b/lib/tdf3/src/client/builders.ts index a8e96d7c..37bb291b 100644 --- a/lib/tdf3/src/client/builders.ts +++ b/lib/tdf3/src/client/builders.ts @@ -519,6 +519,7 @@ export type DecryptParams = { keyMiddleware?: DecryptKeyMiddleware; streamMiddleware?: DecryptStreamMiddleware; assertionVerificationKeys?: AssertionVerificationKeys; + concurrencyLimit?: number; noVerifyAssertions?: boolean; }; @@ -685,6 +686,11 @@ class DecryptParamsBuilder { return freeze({ ..._params }); } + withConcurrencyLimit(limit: number): DecryptParamsBuilder { + this._params.concurrencyLimit = limit; + return this; + } + /** * Generate a parameters object in the form expected by {@link Client#decrypt|decrypt}. *

diff --git a/lib/tdf3/src/client/index.ts b/lib/tdf3/src/client/index.ts index 658fe04c..176e699d 100644 --- a/lib/tdf3/src/client/index.ts +++ b/lib/tdf3/src/client/index.ts @@ -562,6 +562,7 @@ export class Client { streamMiddleware = async (stream: DecoratedReadableStream) => stream, assertionVerificationKeys, noVerifyAssertions, + concurrencyLimit = 1, }: DecryptParams): Promise { const dpopKeys = await this.dpopKeys; let entityObject; @@ -587,6 +588,7 @@ export class Client { allowList: this.allowedKases, authProvider: this.authProvider, chunker, + concurrencyLimit, cryptoService: this.cryptoService, dpopKeys, entity: entityObject, diff --git a/lib/tdf3/src/tdf.ts b/lib/tdf3/src/tdf.ts index c99a370d..8fd6eb7d 100644 --- a/lib/tdf3/src/tdf.ts +++ b/lib/tdf3/src/tdf.ts @@ -65,6 +65,7 @@ import PolicyObject from '../../src/tdf/PolicyObject.js'; import { type CryptoService, type DecryptResult } from './crypto/declarations.js'; import { CentralDirectory } from './utils/zip-reader.js'; import { SymmetricCipher } from './ciphers/symmetric-cipher-base.js'; +import { allPool, anyPool } from '../../src/concurrency.js'; // TODO: input validation on manifest JSON const DEFAULT_SEGMENT_SIZE = 1024 * 1024; @@ -163,6 +164,7 @@ export type DecryptConfiguration = { fileStreamServiceWorker?: string; assertionVerificationKeys?: AssertionVerificationKeys; noVerifyAssertions?: boolean; + concurrencyLimit?: number; }; export type UpsertConfiguration = { @@ -904,17 +906,24 @@ export function splitLookupTableFactory( return splitPotentials; } +type RewrapResponseData = { + key: Uint8Array; + metadata: Record; +}; + async function unwrapKey({ manifest, allowedKases, authProvider, dpopKeys, + concurrencyLimit, entity, cryptoService, }: { manifest: Manifest; allowedKases: OriginAllowList; authProvider: AuthProvider | AppIdAuthProvider; + concurrencyLimit?: number; dpopKeys: CryptoKeyPair; entity: EntityObject | undefined; cryptoService: CryptoService; @@ -928,7 +937,7 @@ async function unwrapKey({ const splitPotentials = splitLookupTableFactory(keyAccess, allowedKases); const isAppIdProvider = authProvider && isAppIdProviderCheck(authProvider); - async function tryKasRewrap(keySplitInfo: KeyAccessObject) { + async function tryKasRewrap(keySplitInfo: KeyAccessObject): Promise { const url = `${keySplitInfo.url}/${isAppIdProvider ? '' : 'v2/'}rewrap`; const ephemeralEncryptionKeys = await cryptoService.cryptoToPemPair( await cryptoService.generateKeyPair() @@ -982,77 +991,44 @@ async function unwrapKey({ }; } - // Get unique split IDs to determine if we have an OR or AND condition - const splitIds = new Set(Object.keys(splitPotentials)); - - // If we have only one split ID, it's an OR condition - if (splitIds.size === 1) { - const [splitId] = splitIds; + const poolSize = concurrencyLimit === undefined ? 1 : concurrencyLimit > 1 ? concurrencyLimit : 1; + const splitPromises: Record> = {}; + for (const splitId of Object.keys(splitPotentials)) { const potentials = splitPotentials[splitId]; - - try { - // OR condition: Try all KAS servers for this split, take first success - const result = await Promise.any( - Object.values(potentials).map(async (keySplitInfo) => { - try { - return await tryKasRewrap(keySplitInfo); - } catch (e) { - // Rethrow with more context - throw handleRewrapError(e as Error | AxiosError); - } - }) + if (!potentials || !Object.keys(potentials).length) { + throw new UnsafeUrlError( + `Unreconstructable key - no valid KAS found for split ${JSON.stringify(splitId)}`, + '' ); - - const reconstructedKey = keyMerge([result.key]); - return { - reconstructedKeyBinary: Binary.fromArrayBuffer(reconstructedKey), - metadata: result.metadata, - }; - } catch (error) { - if (error instanceof AggregateError) { - // All KAS servers failed - throw error.errors[0]; // Throw the first error since we've already wrapped them - } - throw error; } - } else { - // AND condition: We need successful results from all different splits - const splitResults = await Promise.all( - Object.entries(splitPotentials).map(async ([splitId, potentials]) => { - if (!potentials || !Object.keys(potentials).length) { - throw new UnsafeUrlError( - `Unreconstructable key - no valid KAS found for split ${JSON.stringify(splitId)}`, - '' - ); - } - + const anyPromises: Record> = {}; + for (const [kas, keySplitInfo] of Object.entries(potentials)) { + anyPromises[kas] = (async () => { try { - // For each split, try all potential KAS servers until one succeeds - return await Promise.any( - Object.values(potentials).map(async (keySplitInfo) => { - try { - return await tryKasRewrap(keySplitInfo); - } catch (e) { - throw handleRewrapError(e as Error | AxiosError); - } - }) - ); - } catch (error) { - if (error instanceof AggregateError) { - // All KAS servers for this split failed - throw error.errors[0]; // Throw the first error since we've already wrapped them - } - throw error; + return await tryKasRewrap(keySplitInfo); + } catch (e) { + throw handleRewrapError(e as Error | AxiosError); } - }) - ); - + })(); + } + splitPromises[splitId] = anyPool(poolSize, anyPromises); + } + try { + const splitResults = await allPool(poolSize, splitPromises); // Merge all the split keys const reconstructedKey = keyMerge(splitResults.map((r) => r.key)); return { reconstructedKeyBinary: Binary.fromArrayBuffer(reconstructedKey), metadata: splitResults[0].metadata, // Use metadata from first split }; + } catch (e) { + if (e instanceof AggregateError) { + const errors = e.errors; + if (errors.length === 1) { + throw errors[0]; + } + } + throw e; } } diff --git a/lib/tests/mocha/unit/concurrency.spec.ts b/lib/tests/mocha/unit/concurrency.spec.ts new file mode 100644 index 00000000..bb415690 --- /dev/null +++ b/lib/tests/mocha/unit/concurrency.spec.ts @@ -0,0 +1,65 @@ +import { allPool, anyPool } from '../../../src/concurrency.js'; +import { expect } from 'chai'; + +describe('concurrency', () => { + for (const n of [1, 2, 3, 4]) { + describe(`allPool(${n})`, () => { + it(`should resolve all promises with a pool size of ${n}`, async () => { + const promises = { + a: Promise.resolve(1), + b: Promise.resolve(2), + c: Promise.resolve(3), + }; + const result = await allPool(n, promises); + expect(result).to.have.members([1, 2, 3]); + }); + it(`should reject if any promise rejects, n=${n}`, async () => { + const promises = { + a: Promise.resolve(1), + b: Promise.reject(new Error('failure')), + c: Promise.resolve(3), + }; + try { + await allPool(n, promises); + } catch (e) { + expect(e).to.contain({ message: 'failure' }); + } + }); + }); + describe(`anyPool(${n})`, () => { + it('should resolve with the first resolved promise', async () => { + const startTime = Date.now(); + const promises = { + a: new Promise((resolve) => setTimeout(() => resolve(1), 500)), + b: new Promise((resolve) => setTimeout(() => resolve(2), 50)), + c: new Promise((resolve) => setTimeout(() => resolve(3), 1500)), + }; + const result = await anyPool(n, promises); + const endTime = Date.now(); + const elapsed = endTime - startTime; + if (n > 1) { + expect(elapsed).to.be.lessThan(500); + expect(result).to.equal(2); + } else { + expect(elapsed).to.be.greaterThan(50); + expect(elapsed).to.be.lessThan(1000); + expect(result).to.equal(1); + } + }); + + it('should reject if all promises reject', async () => { + const promises = { + a: Promise.reject(new Error('failure1')), + b: Promise.reject(new Error('failure2')), + c: Promise.reject(new Error('failure3')), + }; + try { + await anyPool(n, promises); + } catch (e) { + expect(e).to.be.instanceOf(AggregateError); + expect(e.errors).to.have.lengthOf(3); + } + }); + }); + } +});