diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 6dff9b1283..f61c8b4652 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -55,7 +55,7 @@ interface IPulse is PulseEvents { address provider, uint64 publishTime, bytes32[] calldata priceIds, - uint256 callbackGasLimit + uint32 callbackGasLimit ) external payable returns (uint64 sequenceNumber); /** @@ -80,7 +80,7 @@ interface IPulse is PulseEvents { * @dev This is a fixed fee per request that goes to the Pyth protocol, separate from gas costs * @return pythFeeInWei The base fee in wei that every request must pay */ - function getPythFeeInWei() external view returns (uint128 pythFeeInWei); + function getPythFeeInWei() external view returns (uint96 pythFeeInWei); /** * @notice Calculates the total fee required for a price update request @@ -92,9 +92,9 @@ interface IPulse is PulseEvents { */ function getFee( address provider, - uint256 callbackGasLimit, + uint32 callbackGasLimit, bytes32[] calldata priceIds - ) external view returns (uint128 feeAmount); + ) external view returns (uint96 feeAmount); function getAccruedPythFees() external @@ -116,16 +116,16 @@ interface IPulse is PulseEvents { function withdrawAsFeeManager(address provider, uint128 amount) external; function registerProvider( - uint128 baseFeeInWei, - uint128 feePerFeedInWei, - uint128 feePerGasInWei + uint96 baseFeeInWei, + uint96 feePerFeedInWei, + uint96 feePerGasInWei ) external; function setProviderFee( address provider, - uint128 newBaseFeeInWei, - uint128 newFeePerFeedInWei, - uint128 newFeePerGasInWei + uint96 newBaseFeeInWei, + uint96 newFeePerFeedInWei, + uint96 newFeePerGasInWei ) external; function getProviderInfo( @@ -136,9 +136,9 @@ interface IPulse is PulseEvents { function setDefaultProvider(address provider) external; - function setExclusivityPeriod(uint256 periodSeconds) external; + function setExclusivityPeriod(uint32 periodSeconds) external; - function getExclusivityPeriod() external view returns (uint256); + function getExclusivityPeriod() external view returns (uint32); /** * @notice Gets the first N active requests diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 9483bc015c..1f72ac3646 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -11,11 +11,11 @@ import "./PulseErrors.sol"; abstract contract Pulse is IPulse, PulseState { function _initialize( address admin, - uint128 pythFeeInWei, + uint96 pythFeeInWei, address pythAddress, address defaultProvider, bool prefillRequestStorage, - uint256 exclusivityPeriodSeconds + uint32 exclusivityPeriodSeconds ) internal { require(admin != address(0), "admin is zero address"); require(pythAddress != address(0), "pyth is zero address"); @@ -44,11 +44,6 @@ abstract contract Pulse is IPulse, PulseState { req.publishTime = 1; req.callbackGasLimit = 1; req.requester = address(1); - req.numPriceIds = 0; - // Pre-warm the priceIds array storage - for (uint8 j = 0; j < MAX_PRICE_IDS; j++) { - req.priceIds[j] = bytes32(0); - } } } } @@ -58,7 +53,7 @@ abstract contract Pulse is IPulse, PulseState { address provider, uint64 publishTime, bytes32[] calldata priceIds, - uint256 callbackGasLimit + uint32 callbackGasLimit ) external payable override returns (uint64 requestSequenceNumber) { require( _state.providers[provider].isRegistered, @@ -77,7 +72,7 @@ abstract contract Pulse is IPulse, PulseState { } requestSequenceNumber = _state.currentSequenceNumber++; - uint128 requiredFee = getFee(provider, callbackGasLimit, priceIds); + uint96 requiredFee = getFee(provider, callbackGasLimit, priceIds); if (msg.value < requiredFee) revert InsufficientFee(); Request storage req = allocRequest(requestSequenceNumber); @@ -85,13 +80,21 @@ abstract contract Pulse is IPulse, PulseState { req.publishTime = publishTime; req.callbackGasLimit = callbackGasLimit; req.requester = msg.sender; - req.numPriceIds = uint8(priceIds.length); req.provider = provider; - req.fee = SafeCast.toUint128(msg.value - _state.pythFeeInWei); + req.fee = SafeCast.toUint96(msg.value - _state.pythFeeInWei); - // Copy price IDs to storage + // Create array with the right size + req.priceIdPrefixes = new bytes8[](priceIds.length); + + // Copy only the first 8 bytes of each price ID to storage for (uint8 i = 0; i < priceIds.length; i++) { - req.priceIds[i] = priceIds[i]; + // Extract first 8 bytes of the price ID + bytes32 priceId = priceIds[i]; + bytes8 prefix; + assembly { + prefix := priceId + } + req.priceIdPrefixes[i] = prefix; } _state.accruedFeesInWei += _state.pythFeeInWei; @@ -119,12 +122,21 @@ abstract contract Pulse is IPulse, PulseState { // Verify priceIds match require( - priceIds.length == req.numPriceIds, + priceIds.length == req.priceIdPrefixes.length, "Price IDs length mismatch" ); - for (uint8 i = 0; i < req.numPriceIds; i++) { - if (priceIds[i] != req.priceIds[i]) { - revert InvalidPriceIds(priceIds[i], req.priceIds[i]); + for (uint8 i = 0; i < req.priceIdPrefixes.length; i++) { + // Extract first 8 bytes of the provided price ID + bytes32 priceId = priceIds[i]; + bytes8 prefix; + assembly { + prefix := priceId + } + + // Compare with stored prefix + if (prefix != req.priceIdPrefixes[i]) { + // Now we can directly use the bytes8 prefix in the error + revert InvalidPriceIds(priceIds[i], req.priceIdPrefixes[i]); } } @@ -222,31 +234,31 @@ abstract contract Pulse is IPulse, PulseState { function getFee( address provider, - uint256 callbackGasLimit, + uint32 callbackGasLimit, bytes32[] calldata priceIds - ) public view override returns (uint128 feeAmount) { - uint128 baseFee = _state.pythFeeInWei; // Fixed fee to Pyth + ) public view override returns (uint96 feeAmount) { + uint96 baseFee = _state.pythFeeInWei; // Fixed fee to Pyth // Note: The provider needs to set its fees to include the fee charged by the Pyth contract. // Ideally, we would be able to automatically compute the pyth fees from the priceIds, but the // fee computation on IPyth assumes it has the full updated data. - uint128 providerBaseFee = _state.providers[provider].baseFeeInWei; - uint128 providerFeedFee = SafeCast.toUint128( + uint96 providerBaseFee = _state.providers[provider].baseFeeInWei; + uint96 providerFeedFee = SafeCast.toUint96( priceIds.length * _state.providers[provider].feePerFeedInWei ); - uint128 providerFeeInWei = _state.providers[provider].feePerGasInWei; // Provider's per-gas rate + uint96 providerFeeInWei = _state.providers[provider].feePerGasInWei; // Provider's per-gas rate uint256 gasFee = callbackGasLimit * providerFeeInWei; // Total provider fee based on gas feeAmount = baseFee + providerBaseFee + providerFeedFee + - SafeCast.toUint128(gasFee); // Total fee user needs to pay + SafeCast.toUint96(gasFee); // Total fee user needs to pay } function getPythFeeInWei() public view override - returns (uint128 pythFeeInWei) + returns (uint96 pythFeeInWei) { pythFeeInWei = _state.pythFeeInWei; } @@ -367,9 +379,9 @@ abstract contract Pulse is IPulse, PulseState { } function registerProvider( - uint128 baseFeeInWei, - uint128 feePerFeedInWei, - uint128 feePerGasInWei + uint96 baseFeeInWei, + uint96 feePerFeedInWei, + uint96 feePerGasInWei ) external override { ProviderInfo storage provider = _state.providers[msg.sender]; require(!provider.isRegistered, "Provider already registered"); @@ -382,9 +394,9 @@ abstract contract Pulse is IPulse, PulseState { function setProviderFee( address provider, - uint128 newBaseFeeInWei, - uint128 newFeePerFeedInWei, - uint128 newFeePerGasInWei + uint96 newBaseFeeInWei, + uint96 newFeePerFeedInWei, + uint96 newFeePerGasInWei ) external override { require( _state.providers[provider].isRegistered, @@ -396,9 +408,9 @@ abstract contract Pulse is IPulse, PulseState { "Only provider or fee manager can invoke this method" ); - uint128 oldBaseFee = _state.providers[provider].baseFeeInWei; - uint128 oldFeePerFeed = _state.providers[provider].feePerFeedInWei; - uint128 oldFeePerGas = _state.providers[provider].feePerGasInWei; + uint96 oldBaseFee = _state.providers[provider].baseFeeInWei; + uint96 oldFeePerFeed = _state.providers[provider].feePerFeedInWei; + uint96 oldFeePerGas = _state.providers[provider].feePerGasInWei; _state.providers[provider].baseFeeInWei = newBaseFeeInWei; _state.providers[provider].feePerFeedInWei = newFeePerFeedInWei; _state.providers[provider].feePerGasInWei = newFeePerGasInWei; @@ -437,7 +449,7 @@ abstract contract Pulse is IPulse, PulseState { emit DefaultProviderUpdated(oldProvider, provider); } - function setExclusivityPeriod(uint256 periodSeconds) external override { + function setExclusivityPeriod(uint32 periodSeconds) external override { require( msg.sender == _state.admin, "Only admin can set exclusivity period" @@ -447,7 +459,7 @@ abstract contract Pulse is IPulse, PulseState { emit ExclusivityPeriodUpdated(oldPeriod, periodSeconds); } - function getExclusivityPeriod() external view override returns (uint256) { + function getExclusivityPeriod() external view override returns (uint32) { return _state.exclusivityPeriodSeconds; } diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol index e57719d4da..ebbb6ef332 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol @@ -9,7 +9,7 @@ error InsufficientFee(); error Unauthorized(); error InvalidCallbackGas(); error CallbackFailed(); -error InvalidPriceIds(bytes32 providedPriceIdsHash, bytes32 storedPriceIdsHash); +error InvalidPriceIds(bytes32 providedPriceId, bytes8 storedPriceId); error InvalidCallbackGasLimit(uint256 requested, uint256 stored); error ExceedsMaxPrices(uint32 requested, uint32 maxAllowed); error TooManyPriceIds(uint256 provided, uint256 maximum); diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol index f01069e60d..27c0b37f1d 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol @@ -32,15 +32,15 @@ interface PulseEvents { address newFeeManager ); - event ProviderRegistered(address indexed provider, uint128 feeInWei); + event ProviderRegistered(address indexed provider, uint96 feeInWei); event ProviderFeeUpdated( address indexed provider, - uint128 oldBaseFee, - uint128 oldFeePerFeed, - uint128 oldFeePerGas, - uint128 newBaseFee, - uint128 newFeePerFeed, - uint128 newFeePerGas + uint96 oldBaseFee, + uint96 oldFeePerFeed, + uint96 oldFeePerGas, + uint96 newBaseFee, + uint96 newFeePerFeed, + uint96 newFeePerGas ); event DefaultProviderUpdated(address oldProvider, address newProvider); diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol index 57560d276a..fa5664320f 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol @@ -10,41 +10,62 @@ contract PulseState { uint8 public constant MAX_PRICE_IDS = 10; struct Request { + // Slot 1: 8 + 8 + 4 + 12 = 32 bytes uint64 sequenceNumber; uint64 publishTime; - // TODO: this is going to absolutely explode gas costs. Need to do something smarter here. - // possible solution is to hash the price ids and store the hash instead. - // The ids themselves can be retrieved from the event. - bytes32[MAX_PRICE_IDS] priceIds; - uint8 numPriceIds; // Actual number of price IDs used - uint256 callbackGasLimit; + uint32 callbackGasLimit; + uint96 fee; + // Slot 2: 20 + 12 = 32 bytes address requester; + // 12 bytes padding + + // Slot 3: 20 + 12 = 32 bytes address provider; - uint128 fee; + // 12 bytes padding + + // Dynamic array starts at its own slot + // Store only first 8 bytes of each price ID to save gas + bytes8[] priceIdPrefixes; } struct ProviderInfo { - uint128 baseFeeInWei; - uint128 feePerFeedInWei; - uint128 feePerGasInWei; + // Slot 1: 12 + 12 + 8 = 32 bytes + uint96 baseFeeInWei; + uint96 feePerFeedInWei; + // 8 bytes padding + + // Slot 2: 12 + 16 + 4 = 32 bytes + uint96 feePerGasInWei; uint128 accruedFeesInWei; + // 4 bytes padding + + // Slot 3: 20 + 1 + 11 = 32 bytes address feeManager; bool isRegistered; + // 11 bytes padding } struct State { + // Slot 1: 20 + 4 + 8 = 32 bytes address admin; - uint128 pythFeeInWei; - uint128 accruedFeesInWei; - address pyth; + uint32 exclusivityPeriodSeconds; uint64 currentSequenceNumber; + // Slot 2: 20 + 8 + 4 = 32 bytes + address pyth; + uint64 firstUnfulfilledSeq; + // 4 bytes padding + + // Slot 3: 20 + 12 = 32 bytes address defaultProvider; - uint256 exclusivityPeriodSeconds; + uint96 pythFeeInWei; + // Slot 4: 16 + 16 = 32 bytes + uint128 accruedFeesInWei; + // 16 bytes padding + + // These take their own slots regardless of ordering Request[NUM_REQUESTS] requests; mapping(bytes32 => Request) requestsOverflow; mapping(address => ProviderInfo) providers; - uint64 firstUnfulfilledSeq; // All sequences before this are fulfilled } - State internal _state; } diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol index f3ceafc5ed..98ec6143a7 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol @@ -21,11 +21,11 @@ contract PulseUpgradeable is function initialize( address owner, address admin, - uint128 pythFeeInWei, + uint96 pythFeeInWei, address pythAddress, address defaultProvider, bool prefillRequestStorage, - uint256 exclusivityPeriodSeconds + uint32 exclusivityPeriodSeconds ) external initializer { require(owner != address(0), "owner is zero address"); require(admin != address(0), "admin is zero address"); diff --git a/target_chains/ethereum/contracts/forge-test/GasBenchmark.t.sol b/target_chains/ethereum/contracts/forge-test/GasBenchmark.t.sol index 3f2521938e..ee82c3280f 100644 --- a/target_chains/ethereum/contracts/forge-test/GasBenchmark.t.sol +++ b/target_chains/ethereum/contracts/forge-test/GasBenchmark.t.sol @@ -19,11 +19,12 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils { // It is possible to have more signers but the median seems to be 13. uint8 constant NUM_GUARDIAN_SIGNERS = 13; - // We use 5 prices to form a batch of 5 prices, close to our mainnet transactions. - uint8 constant NUM_PRICES = 5; + // Our mainnet transactions use 5 prices to form a batch of 5 prices, + // but Pulse allows parsing up to 10 requests in a single requests, so we set up 10 prices. + uint8 constant NUM_PRICES = 10; - // We will have less than 512 price for a foreseeable future. - uint8 constant MERKLE_TREE_DEPTH = 9; + // We will have less than 2^11=2048 price for a foreseeable future. + uint8 constant MERKLE_TREE_DEPTH = 11; IWormhole public wormhole; IPyth public pyth; @@ -265,10 +266,8 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils { 100 ); } - - function testBenchmarkParsePriceFeedUpdates1() public { - uint numIds = 1; - + // Helper function to run price feed update benchmark with a specified number of feeds + function _runParsePriceFeedUpdatesBenchmark(uint256 numIds) internal { bytes32[] memory ids = new bytes32[](numIds); for (uint i = 0; i < numIds; i++) { ids[i] = priceIds[i]; @@ -281,64 +280,44 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils { ); } - function testBenchmarkParsePriceFeedUpdates2() public { - uint numIds = 2; + function testBenchmarkParsePriceFeedUpdates1() public { + _runParsePriceFeedUpdatesBenchmark(1); + } - bytes32[] memory ids = new bytes32[](numIds); - for (uint i = 0; i < numIds; i++) { - ids[i] = priceIds[i]; - } - pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee[numIds - 1]}( - freshPricesUpdateData[numIds - 1], - ids, - 0, - 50 - ); + function testBenchmarkParsePriceFeedUpdates2() public { + _runParsePriceFeedUpdatesBenchmark(2); } function testBenchmarkParsePriceFeedUpdates3() public { - uint numIds = 3; - - bytes32[] memory ids = new bytes32[](numIds); - for (uint i = 0; i < numIds; i++) { - ids[i] = priceIds[i]; - } - pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee[numIds - 1]}( - freshPricesUpdateData[numIds - 1], - ids, - 0, - 50 - ); + _runParsePriceFeedUpdatesBenchmark(3); } function testBenchmarkParsePriceFeedUpdates4() public { - uint numIds = 4; - - bytes32[] memory ids = new bytes32[](numIds); - for (uint i = 0; i < numIds; i++) { - ids[i] = priceIds[i]; - } - pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee[numIds - 1]}( - freshPricesUpdateData[numIds - 1], - ids, - 0, - 50 - ); + _runParsePriceFeedUpdatesBenchmark(4); } function testBenchmarkParsePriceFeedUpdates5() public { - uint numIds = 5; + _runParsePriceFeedUpdatesBenchmark(5); + } - bytes32[] memory ids = new bytes32[](numIds); - for (uint i = 0; i < numIds; i++) { - ids[i] = priceIds[i]; - } - pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee[numIds - 1]}( - freshPricesUpdateData[numIds - 1], - ids, - 0, - 50 - ); + function testBenchmarkParsePriceFeedUpdates6() public { + _runParsePriceFeedUpdatesBenchmark(6); + } + + function testBenchmarkParsePriceFeedUpdates7() public { + _runParsePriceFeedUpdatesBenchmark(7); + } + + function testBenchmarkParsePriceFeedUpdates8() public { + _runParsePriceFeedUpdatesBenchmark(8); + } + + function testBenchmarkParsePriceFeedUpdates9() public { + _runParsePriceFeedUpdatesBenchmark(9); + } + + function testBenchmarkParsePriceFeedUpdates10() public { + _runParsePriceFeedUpdatesBenchmark(10); } function testBenchmarkParsePriceFeedUpdatesForAllPriceFeedsShuffledSubsetPriceIds() diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 4041076bfa..04fc3554b5 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -94,10 +94,10 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { address public pyth; address public defaultProvider; // Constants - uint128 constant PYTH_FEE = 1 wei; - uint128 constant DEFAULT_PROVIDER_FEE_PER_GAS = 1 wei; - uint128 constant DEFAULT_PROVIDER_BASE_FEE = 1 wei; - uint128 constant DEFAULT_PROVIDER_FEE_PER_FEED = 10 wei; + uint96 constant PYTH_FEE = 1 wei; + uint96 constant DEFAULT_PROVIDER_FEE_PER_GAS = 1 wei; + uint96 constant DEFAULT_PROVIDER_BASE_FEE = 1 wei; + uint96 constant DEFAULT_PROVIDER_FEE_PER_FEED = 10 wei; function setUp() public { owner = address(1); @@ -128,7 +128,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { // Helper function to calculate total fee // FIXME: I think this helper probably needs to take some arguments. - function calculateTotalFee() internal view returns (uint128) { + function calculateTotalFee() internal view returns (uint96) { return pulse.getFee(defaultProvider, CALLBACK_GAS_LIMIT, createPriceIds()); } @@ -142,26 +142,28 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { // Fund the consumer contract with enough ETH for higher gas price vm.deal(address(consumer), 1 ether); - uint128 totalFee = calculateTotalFee(); + uint96 totalFee = calculateTotalFee(); // Create the event data we expect to see + bytes8[] memory expectedPriceIdPrefixes = new bytes8[](2); + { + bytes32 priceId0 = priceIds[0]; + bytes32 priceId1 = priceIds[1]; + bytes8 prefix0; + bytes8 prefix1; + assembly { + prefix0 := priceId0 + prefix1 := priceId1 + } + expectedPriceIdPrefixes[0] = prefix0; + expectedPriceIdPrefixes[1] = prefix1; + } + PulseState.Request memory expectedRequest = PulseState.Request({ sequenceNumber: 1, publishTime: publishTime, - priceIds: [ - priceIds[0], - priceIds[1], - bytes32(0), // Fill remaining slots with zero - bytes32(0), - bytes32(0), - bytes32(0), - bytes32(0), - bytes32(0), - bytes32(0), - bytes32(0) - ], - numPriceIds: 2, - callbackGasLimit: CALLBACK_GAS_LIMIT, + priceIdPrefixes: expectedPriceIdPrefixes, + callbackGasLimit: uint32(CALLBACK_GAS_LIMIT), requester: address(consumer), provider: defaultProvider, fee: totalFee - PYTH_FEE @@ -182,9 +184,15 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { PulseState.Request memory lastRequest = pulse.getRequest(1); assertEq(lastRequest.sequenceNumber, expectedRequest.sequenceNumber); assertEq(lastRequest.publishTime, expectedRequest.publishTime); - assertEq(lastRequest.numPriceIds, expectedRequest.numPriceIds); - for (uint8 i = 0; i < lastRequest.numPriceIds; i++) { - assertEq(lastRequest.priceIds[i], expectedRequest.priceIds[i]); + assertEq( + lastRequest.priceIdPrefixes.length, + expectedRequest.priceIdPrefixes.length + ); + for (uint8 i = 0; i < lastRequest.priceIdPrefixes.length; i++) { + assertEq( + lastRequest.priceIdPrefixes[i], + expectedRequest.priceIdPrefixes[i] + ); } assertEq( lastRequest.callbackGasLimit, @@ -219,7 +227,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { // Fund the consumer contract vm.deal(address(consumer), 1 gwei); - uint128 totalFee = calculateTotalFee(); + uint96 totalFee = calculateTotalFee(); // Step 1: Make the request as consumer vm.prank(address(consumer)); @@ -404,7 +412,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { uint64 futureTime = SafeCast.toUint64(block.timestamp + 10); // 10 seconds in future vm.deal(address(consumer), 1 gwei); - uint128 totalFee = calculateTotalFee(); + uint96 totalFee = calculateTotalFee(); vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: totalFee @@ -444,7 +452,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { uint64 farFutureTime = SafeCast.toUint64(block.timestamp + 61); // Just over 1 minute vm.deal(address(consumer), 1 gwei); - uint128 totalFee = calculateTotalFee(); + uint96 totalFee = calculateTotalFee(); vm.prank(address(consumer)); vm.expectRevert("Too far in future"); @@ -491,7 +499,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { function testGetFee() public { // Test with different gas limits to verify fee calculation - uint256[] memory gasLimits = new uint256[](3); + uint32[] memory gasLimits = new uint32[](3); gasLimits[0] = 100_000; gasLimits[1] = 500_000; gasLimits[2] = 1_000_000; @@ -499,15 +507,15 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { bytes32[] memory priceIds = createPriceIds(); for (uint256 i = 0; i < gasLimits.length; i++) { - uint256 gasLimit = gasLimits[i]; - uint128 expectedFee = SafeCast.toUint128( + uint32 gasLimit = gasLimits[i]; + uint96 expectedFee = SafeCast.toUint96( DEFAULT_PROVIDER_BASE_FEE + DEFAULT_PROVIDER_FEE_PER_FEED * priceIds.length + DEFAULT_PROVIDER_FEE_PER_GAS * gasLimit ) + PYTH_FEE; - uint128 actualFee = pulse.getFee( + uint96 actualFee = pulse.getFee( defaultProvider, gasLimit, priceIds @@ -520,13 +528,13 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { } // Test with zero gas limit - uint128 expectedMinFee = SafeCast.toUint128( + uint96 expectedMinFee = SafeCast.toUint96( PYTH_FEE + DEFAULT_PROVIDER_BASE_FEE + DEFAULT_PROVIDER_FEE_PER_FEED * priceIds.length ); - uint128 actualMinFee = pulse.getFee(defaultProvider, 0, priceIds); + uint96 actualMinFee = pulse.getFee(defaultProvider, 0, priceIds); assertEq( actualMinFee, expectedMinFee, @@ -607,7 +615,10 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { uint256 managerBalanceBefore = feeManager.balance; vm.prank(feeManager); - pulse.withdrawAsFeeManager(defaultProvider, providerAccruedFees); + pulse.withdrawAsFeeManager( + defaultProvider, + uint96(providerAccruedFees) + ); assertEq( feeManager.balance, @@ -672,11 +683,17 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { // Should revert when trying to execute with wrong priceIds vm.prank(defaultProvider); + // Extract first 8 bytes of the price ID for the error expectation + bytes8 storedPriceIdPrefix; + assembly { + storedPriceIdPrefix := mload(add(priceIds, 32)) + } + vm.expectRevert( abi.encodeWithSelector( InvalidPriceIds.selector, wrongPriceIds[0], - priceIds[0] + storedPriceIdPrefix ) ); pulse.executeCallback( @@ -696,7 +713,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { } vm.deal(address(consumer), 1 gwei); - uint128 totalFee = calculateTotalFee(); + uint96 totalFee = calculateTotalFee(); vm.prank(address(consumer)); vm.expectRevert( @@ -716,7 +733,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { function testProviderRegistration() public { address provider = address(0x123); - uint128 providerFee = 1000; + uint96 providerFee = 1000; vm.prank(provider); pulse.registerProvider(providerFee, providerFee, providerFee); @@ -728,12 +745,12 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { function testSetProviderFee() public { address provider = address(0x123); - uint128 initialBaseFee = 1000; - uint128 initialFeePerFeed = 2000; - uint128 initialFeePerGas = 3000; - uint128 newFeePerFeed = 4000; - uint128 newBaseFee = 5000; - uint128 newFeePerGas = 6000; + uint96 initialBaseFee = 1000; + uint96 initialFeePerFeed = 2000; + uint96 initialFeePerGas = 3000; + uint96 newFeePerFeed = 4000; + uint96 newBaseFee = 5000; + uint96 newFeePerGas = 6000; vm.prank(provider); pulse.registerProvider( @@ -753,7 +770,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { function testDefaultProvider() public { address provider = address(0x123); - uint128 providerFee = 1000; + uint96 providerFee = 1000; vm.prank(provider); pulse.registerProvider(providerFee, providerFee, providerFee); @@ -766,7 +783,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { function testRequestWithProvider() public { address provider = address(0x123); - uint128 providerFee = 1000; + uint96 providerFee = 1000; vm.prank(provider); pulse.registerProvider(providerFee, providerFee, providerFee); @@ -1128,7 +1145,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils { defaultProvider, publishTime, priceIds, - callbackGasLimit + uint32(callbackGasLimit) ); // Complete every third request to create gaps diff --git a/target_chains/ethereum/contracts/forge-test/PulseGasBenchmark.t.sol b/target_chains/ethereum/contracts/forge-test/PulseGasBenchmark.t.sol index fcba9908cb..f82bf1d9ec 100644 --- a/target_chains/ethereum/contracts/forge-test/PulseGasBenchmark.t.sol +++ b/target_chains/ethereum/contracts/forge-test/PulseGasBenchmark.t.sol @@ -10,7 +10,7 @@ import "../contracts/pulse/PulseState.sol"; import "../contracts/pulse/PulseEvents.sol"; import "../contracts/pulse/PulseErrors.sol"; import "./utils/PulseTestUtils.t.sol"; - +import {console} from "forge-std/console.sol"; contract PulseGasBenchmark is Test, PulseTestUtils { ERC1967Proxy public proxy; PulseUpgradeable public pulse; @@ -21,10 +21,10 @@ contract PulseGasBenchmark is Test, PulseTestUtils { address public pyth; address public defaultProvider; - uint128 constant PYTH_FEE = 1 wei; - uint128 constant DEFAULT_PROVIDER_FEE_PER_GAS = 1 wei; - uint128 constant DEFAULT_PROVIDER_BASE_FEE = 1 wei; - uint128 constant DEFAULT_PROVIDER_FEE_PER_FEED = 10 wei; + uint96 constant PYTH_FEE = 1 wei; + uint96 constant DEFAULT_PROVIDER_FEE_PER_GAS = 1 wei; + uint96 constant DEFAULT_PROVIDER_BASE_FEE = 1 wei; + uint96 constant DEFAULT_PROVIDER_FEE_PER_FEED = 10 wei; function setUp() public { owner = address(1); @@ -41,7 +41,7 @@ contract PulseGasBenchmark is Test, PulseTestUtils { PYTH_FEE, pyth, defaultProvider, - false, + true, 15 ); vm.prank(defaultProvider); @@ -66,12 +66,13 @@ contract PulseGasBenchmark is Test, PulseTestUtils { createMockUpdateData(priceFeeds); } - function testBasicFlow() public { + // Helper function to run the basic request + fulfill flow with a specified number of feeds + function _runBenchmarkWithFeeds(uint256 numFeeds) internal { uint64 timestamp = SafeCast.toUint64(block.timestamp); - bytes32[] memory priceIds = createPriceIds(); + bytes32[] memory priceIds = createPriceIds(numFeeds); - uint128 callbackGasLimit = 100000; - uint128 totalFee = pulse.getFee( + uint32 callbackGasLimit = 100000; + uint96 totalFee = pulse.getFee( defaultProvider, callbackGasLimit, priceIds @@ -83,7 +84,8 @@ contract PulseGasBenchmark is Test, PulseTestUtils { }(defaultProvider, timestamp, priceIds, callbackGasLimit); PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( - timestamp + timestamp, + numFeeds ); mockParsePriceFeedUpdates(pyth, priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); @@ -95,6 +97,163 @@ contract PulseGasBenchmark is Test, PulseTestUtils { priceIds ); } + + function testBasicFlowWith01Feeds() public { + _runBenchmarkWithFeeds(1); + } + + function testBasicFlowWith02Feeds() public { + _runBenchmarkWithFeeds(2); + } + + function testBasicFlowWith04Feeds() public { + _runBenchmarkWithFeeds(4); + } + + function testBasicFlowWith08Feeds() public { + _runBenchmarkWithFeeds(8); + } + + function testBasicFlowWith10Feeds() public { + _runBenchmarkWithFeeds(10); + } + + // This test checks the gas usage for worst-case out-of-order fulfillment. + // It creates 10 requests, and then fulfills them in reverse order. + // + // The last fulfillment will be the most expensive since it needs + // to linearly scan through all the fulfilled requests in storage + // in order to update _state.lastUnfulfilledReq + // + // NOTE: Run test with `forge test --gas-report --match-test testMultipleRequestsOutOfOrderFulfillment` + // and observe the `max` value for `executeCallback` to see the cost of the most expensive request. + function testMultipleRequestsOutOfOrderFulfillment() public { + uint64 timestamp = SafeCast.toUint64(block.timestamp); + bytes32[] memory priceIds = createPriceIds(2); + uint32 callbackGasLimit = 100000; + uint128 totalFee = pulse.getFee( + defaultProvider, + callbackGasLimit, + priceIds + ); + + // Create 10 requests + uint64[] memory sequenceNumbers = new uint64[](10); + vm.deal(address(consumer), 10 ether); + + for (uint i = 0; i < 10; i++) { + vm.prank(address(consumer)); + sequenceNumbers[i] = pulse.requestPriceUpdatesWithCallback{ + value: totalFee + }( + defaultProvider, + timestamp + uint64(i), + priceIds, + callbackGasLimit + ); + } + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + timestamp + ); + mockParsePriceFeedUpdates(pyth, priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + // Execute callbacks in reverse + uint startGas = gasleft(); + for (uint i = 9; i > 0; i--) { + pulse.executeCallback( + defaultProvider, + sequenceNumbers[i], + updateData, + priceIds + ); + } + uint midGas = gasleft(); + + // Execute the first request last - this would be the most expensive + // in the original implementation as it would need to loop through + // all sequence numbers + pulse.executeCallback( + defaultProvider, + sequenceNumbers[0], + updateData, + priceIds + ); + uint endGas = gasleft(); + + // Log gas usage for the last callback which would be the most expensive + // in the original implementation (need to run test with -vv) + console.log( + "Gas used for last callback (seq 1): %s", + vm.toString(midGas - endGas) + ); + console.log( + "Gas used for all other callbacks: %s", + vm.toString(startGas - midGas) + ); + } + + // Helper function to run the overflow mapping benchmark with a specified number of feeds + function _runOverflowBenchmarkWithFeeds(uint256 numFeeds) internal { + uint64 timestamp = SafeCast.toUint64(block.timestamp); + bytes32[] memory priceIds = createPriceIds(numFeeds); + uint32 callbackGasLimit = 100000; + uint128 totalFee = pulse.getFee( + defaultProvider, + callbackGasLimit, + priceIds + ); + + // Create NUM_REQUESTS requests to fill up the main array + // The constant is defined in PulseState.sol as 32 + uint64[] memory sequenceNumbers = new uint64[](32); + vm.deal(address(consumer), 50 ether); + + // Use the same timestamp for all requests to avoid "Too far in future" error + for (uint i = 0; i < 32; i++) { + vm.prank(address(consumer)); + sequenceNumbers[i] = pulse.requestPriceUpdatesWithCallback{ + value: totalFee + }(defaultProvider, timestamp, priceIds, callbackGasLimit); + } + + // Create one more request that will go to the overflow mapping + // (This could potentially happen earlier if a shortKey collides, + // but this guarantees it.) + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: totalFee}( + defaultProvider, + timestamp, + priceIds, + callbackGasLimit + ); + } + + // These tests benchmark the gas usage when a new request overflows the fixed-size + // request array and gets stored in the overflow mapping. + // + // NOTE: Run test with `forge test --gas-report --match-test testOverflowMappingGasUsageWithXXFeeds` + // and observe the `max` value for `executeCallback` to see the cost of the overflowing request. + function testOverflowMappingGasUsageWith01Feeds() public { + _runOverflowBenchmarkWithFeeds(1); + } + + function testOverflowMappingGasUsageWith02Feeds() public { + _runOverflowBenchmarkWithFeeds(2); + } + + function testOverflowMappingGasUsageWith04Feeds() public { + _runOverflowBenchmarkWithFeeds(4); + } + + function testOverflowMappingGasUsageWith08Feeds() public { + _runOverflowBenchmarkWithFeeds(8); + } + + function testOverflowMappingGasUsageWith10Feeds() public { + _runOverflowBenchmarkWithFeeds(10); + } } // A simple consumer that does nothing with the price updates. diff --git a/target_chains/ethereum/contracts/forge-test/utils/PulseTestUtils.t.sol b/target_chains/ethereum/contracts/forge-test/utils/PulseTestUtils.t.sol index 9341cac989..eb400aa3e4 100644 --- a/target_chains/ethereum/contracts/forge-test/utils/PulseTestUtils.t.sol +++ b/target_chains/ethereum/contracts/forge-test/utils/PulseTestUtils.t.sol @@ -12,6 +12,22 @@ abstract contract PulseTestUtils is Test { 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43; bytes32 constant ETH_PRICE_FEED_ID = 0xff61491a931112ddf1bd8147cd1b641375f79f5825126d665480874634fd0ace; + bytes32 constant SOL_PRICE_FEED_ID = + 0xef0d8b6fda2ceba41da15d4095d1da392a0d2f8ed0c6c7bc0f4cfac8c280b56d; + bytes32 constant AVAX_PRICE_FEED_ID = + 0x93da3352f9f1d105fdfe4971cfa80e9dd777bfc5d0f683ebb6e1294b92137bb7; + bytes32 constant MELANIA_PRICE_FEED_ID = + 0x8fef7d52c7f4e3a6258d663f9d27e64a1b6fd95ab5f7d545dbf9a515353d0064; + bytes32 constant PYTH_PRICE_FEED_ID = + 0x0bbf28e9a841a1cc788f6a361b17ca072d0ea3098a1e5df1c3922d06719579ff; + bytes32 constant UNI_PRICE_FEED_ID = + 0x78d185a741d07edb3412b09008b7c5cfb9bbbd7d568bf00ba737b456ba171501; + bytes32 constant AAVE_PRICE_FEED_ID = + 0x2b9ab1e972a281585084148ba1389800799bd4be63b957507db1349314e47445; + bytes32 constant DOGE_PRICE_FEED_ID = + 0xdcef50dd0a4cd2dcc17e45df1676dcb336a11a61c69df7a0299b0150c672d25c; + bytes32 constant ADA_PRICE_FEED_ID = + 0x2a01deaec9e51a579277b34b122399984d0bbf57e2458a7e42fecd2829867a0d; // Price feed constants int8 constant MOCK_PRICE_FEED_EXPO = -8; @@ -23,35 +39,72 @@ abstract contract PulseTestUtils is Test { // Fee charged by the Pyth oracle per price feed uint constant MOCK_PYTH_FEE_PER_FEED = 10 wei; - uint128 constant CALLBACK_GAS_LIMIT = 1_000_000; + uint32 constant CALLBACK_GAS_LIMIT = 1_000_000; - // Helper function to create price IDs array + // Helper function to create price IDs array with default 2 feeds function createPriceIds() internal pure returns (bytes32[] memory) { - bytes32[] memory priceIds = new bytes32[](2); - priceIds[0] = BTC_PRICE_FEED_ID; - priceIds[1] = ETH_PRICE_FEED_ID; + return createPriceIds(2); + } + + // Helper function to create price IDs array with variable number of feeds + function createPriceIds( + uint256 numFeeds + ) internal pure returns (bytes32[] memory) { + require(numFeeds <= 10, "Too many price feeds requested"); + bytes32[] memory priceIds = new bytes32[](numFeeds); + + if (numFeeds > 0) priceIds[0] = BTC_PRICE_FEED_ID; + if (numFeeds > 1) priceIds[1] = ETH_PRICE_FEED_ID; + if (numFeeds > 2) priceIds[2] = SOL_PRICE_FEED_ID; + if (numFeeds > 3) priceIds[3] = AVAX_PRICE_FEED_ID; + if (numFeeds > 4) priceIds[4] = MELANIA_PRICE_FEED_ID; + if (numFeeds > 5) priceIds[5] = PYTH_PRICE_FEED_ID; + if (numFeeds > 6) priceIds[6] = UNI_PRICE_FEED_ID; + if (numFeeds > 7) priceIds[7] = AAVE_PRICE_FEED_ID; + if (numFeeds > 8) priceIds[8] = DOGE_PRICE_FEED_ID; + if (numFeeds > 9) priceIds[9] = ADA_PRICE_FEED_ID; + return priceIds; } - // Helper function to create mock price feeds + // Helper function to create mock price feeds with default 2 feeds function createMockPriceFeeds( uint256 publishTime ) internal pure returns (PythStructs.PriceFeed[] memory) { + return createMockPriceFeeds(publishTime, 2); + } + + // Helper function to create mock price feeds with variable number of feeds + function createMockPriceFeeds( + uint256 publishTime, + uint256 numFeeds + ) internal pure returns (PythStructs.PriceFeed[] memory) { + require(numFeeds <= 10, "Too many price feeds requested"); PythStructs.PriceFeed[] memory priceFeeds = new PythStructs.PriceFeed[]( - 2 + numFeeds ); - priceFeeds[0].id = BTC_PRICE_FEED_ID; - priceFeeds[0].price.price = MOCK_BTC_PRICE; - priceFeeds[0].price.conf = MOCK_BTC_CONF; - priceFeeds[0].price.expo = MOCK_PRICE_FEED_EXPO; - priceFeeds[0].price.publishTime = publishTime; + bytes32[] memory priceIds = createPriceIds(numFeeds); + + for (uint256 i = 0; i < numFeeds; i++) { + priceFeeds[i].id = priceIds[i]; + + // Use appropriate price and conf based on the price ID + if (priceIds[i] == BTC_PRICE_FEED_ID) { + priceFeeds[i].price.price = MOCK_BTC_PRICE; + priceFeeds[i].price.conf = MOCK_BTC_CONF; + } else if (priceIds[i] == ETH_PRICE_FEED_ID) { + priceFeeds[i].price.price = MOCK_ETH_PRICE; + priceFeeds[i].price.conf = MOCK_ETH_CONF; + } else { + // Default to BTC price for other feeds + priceFeeds[i].price.price = MOCK_BTC_PRICE; + priceFeeds[i].price.conf = MOCK_BTC_CONF; + } - priceFeeds[1].id = ETH_PRICE_FEED_ID; - priceFeeds[1].price.price = MOCK_ETH_PRICE; - priceFeeds[1].price.conf = MOCK_ETH_CONF; - priceFeeds[1].price.expo = MOCK_PRICE_FEED_EXPO; - priceFeeds[1].price.publishTime = publishTime; + priceFeeds[i].price.expo = MOCK_PRICE_FEED_EXPO; + priceFeeds[i].price.publishTime = publishTime; + } return priceFeeds; } @@ -77,13 +130,17 @@ abstract contract PulseTestUtils is Test { ); } - // Helper function to create mock update data + // Helper function to create mock update data for variable feeds function createMockUpdateData( PythStructs.PriceFeed[] memory priceFeeds ) internal pure returns (bytes[] memory) { - bytes[] memory updateData = new bytes[](2); - updateData[0] = abi.encode(priceFeeds[0]); - updateData[1] = abi.encode(priceFeeds[1]); + uint256 numFeeds = priceFeeds.length; + bytes[] memory updateData = new bytes[](numFeeds); + + for (uint256 i = 0; i < numFeeds; i++) { + updateData[i] = abi.encode(priceFeeds[i]); + } + return updateData; } @@ -104,7 +161,7 @@ abstract contract PulseTestUtils is Test { publishTime = SafeCast.toUint64(block.timestamp); vm.deal(consumerAddress, 1 gwei); - uint128 totalFee = pulse.getFee(provider, CALLBACK_GAS_LIMIT, priceIds); + uint96 totalFee = pulse.getFee(provider, CALLBACK_GAS_LIMIT, priceIds); vm.prank(consumerAddress); sequenceNumber = pulse.requestPriceUpdatesWithCallback{value: totalFee}(