diff --git a/lib/hamza-escrow b/lib/hamza-escrow index 18a04d2..6d05b52 160000 --- a/lib/hamza-escrow +++ b/lib/hamza-escrow @@ -1 +1 @@ -Subproject commit 18a04d23aff663406151cf3fc5e7d1e5a0342713 +Subproject commit 6d05b52207ad80bd56421595a5ad8cd84b3c537c diff --git a/scripts/DeployHamzaVault.s.sol b/scripts/DeployHamzaVault.s.sol index 7fe618b..3484f9e 100644 --- a/scripts/DeployHamzaVault.s.sol +++ b/scripts/DeployHamzaVault.s.sol @@ -7,6 +7,7 @@ import "@baal/Baal.sol"; import "@baal/BaalSummoner.sol"; import "../src/CommunityVault.sol"; +import "../src/CommunityRewardsCalculator.sol"; import "../src/tokens/GovernanceToken.sol"; import "../src/GovernanceVault.sol"; @@ -211,6 +212,8 @@ contract DeployHamzaVault is Script { // 5) Deploy the Community Vault CommunityVault communityVault = new CommunityVault(hatsSecurityContextAddr); vault = payable(address(communityVault)); + CommunityRewardsCalculator calculator = new CommunityRewardsCalculator(); + communityVault.setCommunityRewardsCalculator(calculator); // 6) Summon the Baal DAO BaalSummoner summoner = BaalSummoner(BAAL_SUMMONER); @@ -412,10 +415,10 @@ contract DeployHamzaVault is Script { bool autoRelease = config.readBool(".escrow.autoRelease"); // 15) Deploy PurchaseTracker - PurchaseTracker purchaseTracker = new PurchaseTracker(securityContext, vault, lootTokenAddr); + PurchaseTracker purchaseTracker = new PurchaseTracker(securityContext, lootTokenAddr); //setPurchaseTracker in community vault - CommunityVault(vault).setPurchaseTracker(address(purchaseTracker), lootTokenAddr); + CommunityVault(vault).setPurchaseTracker(address(purchaseTracker)); purchaseTrackerAddr = address(purchaseTracker); @@ -438,7 +441,7 @@ contract DeployHamzaVault is Script { function logDeployedAddresses( address newBaalAddr, - address vault, + address communityVault, address govTokenAddr, address govVaultAddr, address timelockAddr @@ -446,7 +449,7 @@ contract DeployHamzaVault is Script { console2.log("Owner One (from PRIVATE_KEY):", OWNER_ONE); console2.log("Owner Two (from config): ", OWNER_TWO); - console2.log("CommunityVault deployed at:", vault); + console2.log("CommunityVault deployed at:", communityVault); console2.log("BaalSummoner at:", BAAL_SUMMONER); console2.log("Baal (Hamza Vault) deployed at:", newBaalAddr); diff --git a/src/CommunityRewardsCalculator.sol b/src/CommunityRewardsCalculator.sol new file mode 100644 index 0000000..0b907a8 --- /dev/null +++ b/src/CommunityRewardsCalculator.sol @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.20; + +import "@hamza-escrow/IPurchaseTracker.sol"; +import "./ICommunityRewardsCalculator.sol"; + +/** + * @title CommunityRewardsCalculator + * @dev Contains the logic for calculating who gets what rewards, and for what reasons. The rewards are + * distributed through the CommunityVault. + */ +contract CommunityRewardsCalculator is ICommunityRewardsCalculator { + + function getRewardsToDistribute( + address token, + address[] calldata recipients, + IPurchaseTracker purchaseTracker + ) external returns (uint256[] memory) { + uint256[] memory amounts = new uint256[](recipients.length); + + // for every purchase or sale made by the recipient, distribute 1 loot token + for (uint i=0; i uint256) public tokenBalances; + // Keeps a count of already distributed rewards + mapping(address => mapping(address => uint256)) public rewardsDistributed; + // Governance staking contract address address public governanceVault; // Address for purchase tracker address public purchaseTracker; + // Address for rewards calculator + ICommunityRewardsCalculator public rewardsCalculator; + // Events event Deposit(address indexed token, address indexed from, uint256 amount); event Withdraw(address indexed token, address indexed to, uint256 amount); @@ -91,35 +97,36 @@ contract CommunityVault is HasSecurityContext { address[] calldata recipients, uint256[] calldata amounts ) external onlyRole(Roles.SYSTEM_ROLE) { - require(recipients.length == amounts.length, "Mismatched arrays"); - - for (uint256 i = 0; i < recipients.length; i++) { - require(tokenBalances[token] >= amounts[i], "Insufficient balance"); - - if (token == address(0)) { - // ETH distribution - (bool success, ) = recipients[i].call{value: amounts[i]}(""); - require(success, "ETH transfer failed"); - } else { - // ERC20 distribution - IERC20(token).safeTransfer(recipients[i], amounts[i]); - } + _distribute(token, recipients, amounts); + } - tokenBalances[token] -= amounts[i]; + /** + * @dev Distribute tokens or ETH from the community vault to multiple recipients, using the + * CommunityRewardsCalculator to calculate the amounts to reward each recipient. + * @param token The address of the token + * @param recipients The array of recipient addresses + */ + function distributeRewards(address token, address[] memory recipients) external onlyRole(Roles.SYSTEM_ROLE) { + _distributeRewards(token, recipients); + } - emit Distribute(token, recipients[i], amounts[i]); - } + /** + * @dev Allows a rightful recipient of rewards to claim rewards that have been allocated to them. + * @param token The address of the token + */ + function claimRewards(address token) external { + address[] memory recipients = new address[](1); + recipients[0] = msg.sender; + _distributeRewards(token, recipients); } /** - * @dev Set the governance vault address and grant it unlimited allowance for `lootToken`. - * Must be called by an admin role or similar. - * @param vault The address of the governance vault - * @param lootToken The address of the ERC20 token for which you'd like to grant unlimited allowance - */ - function setGovernanceVault(address vault, address lootToken) - external - { + * @dev Set the governance vault address and grant it unlimited allowance for `lootToken`. + * Must be called by an admin role or similar. + * @param vault The address of the governance vault + * @param lootToken The address of the ERC20 token for which you'd like to grant unlimited allowance + */ + function setGovernanceVault(address vault, address lootToken) external onlyRole(Roles.SYSTEM_ROLE) { require(vault != address(0), "Invalid staking contract address"); require(lootToken != address(0), "Invalid loot token address"); @@ -130,15 +137,20 @@ contract CommunityVault is HasSecurityContext { IERC20(lootToken).safeApprove(vault, type(uint256).max); } - function setPurchaseTracker(address _purchaseTracker, address lootToken) external { + /** + * @dev Sets the purchase tracker that is used to keep track of who has done what, in order to get rewards. + */ + function setPurchaseTracker(address _purchaseTracker) external onlyRole(Roles.SYSTEM_ROLE) { require(_purchaseTracker != address(0), "Invalid purchase tracker address"); - require(lootToken != address(0), "Invalid loot token address"); purchaseTracker = _purchaseTracker; + } - // Grant unlimited allowance to the purchase tracker - IERC20(lootToken).safeApprove(_purchaseTracker, 0); - IERC20(lootToken).safeApprove(_purchaseTracker, type(uint256).max); + /** + * @dev Sets the address of the contract which holds the logic for calculating how to divide up rewards. + */ + function setCommunityRewardsCalculator(ICommunityRewardsCalculator calculator) external onlyRole(Roles.SYSTEM_ROLE) { + rewardsCalculator = calculator; } /** @@ -149,6 +161,47 @@ contract CommunityVault is HasSecurityContext { return tokenBalances[token]; } + function _distributeRewards(address token, address[] memory recipients) internal { + if (address(rewardsCalculator) != address(0) && address(purchaseTracker) != address(0)) { + + //get rewards to distribute + uint256[] memory amounts = rewardsCalculator.getRewardsToDistribute( + token, recipients, IPurchaseTracker(purchaseTracker) + ); + + _distribute(token, recipients, amounts); + } + } + + function _distribute( + address token, + address[] memory recipients, + uint256[] memory amounts + ) internal { + require(recipients.length == amounts.length, "Mismatched arrays"); + + for (uint256 i = 0; i < recipients.length; i++) { + require(tokenBalances[token] >= amounts[i], "Insufficient balance"); + + if (token == address(0)) { + // ETH distribution + (bool success, ) = recipients[i].call{value: amounts[i]}(""); + require(success, "ETH transfer failed"); + } else { + // ERC20 distribution + IERC20(token).safeTransfer(recipients[i], amounts[i]); + } + + // decrement balance + tokenBalances[token] -= amounts[i]; + + // record the distribution + rewardsDistributed[token][recipients[i]] += amounts[i]; + + emit Distribute(token, recipients[i], amounts[i]); + } + } + // Fallback to receive ETH receive() external payable {} } diff --git a/src/ICommunityRewardsCalculator.sol b/src/ICommunityRewardsCalculator.sol new file mode 100644 index 0000000..bc7102b --- /dev/null +++ b/src/ICommunityRewardsCalculator.sol @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.20; + +import "@hamza-escrow/IPurchaseTracker.sol"; + +/** + * @title ICommunityRewardsCalculator + * @dev Defines the logic for calculating who gets what rewards, and for what reasons. The rewards are + * distributed through the CommunityVault. + */ +interface ICommunityRewardsCalculator { + function getRewardsToDistribute( + address token, + address[] calldata recipients, + IPurchaseTracker purchaseTracker + ) external returns (uint256[] memory); +} diff --git a/src/PurchaseTracker.sol b/src/PurchaseTracker.sol index 78ee9fa..8a4f167 100644 --- a/src/PurchaseTracker.sol +++ b/src/PurchaseTracker.sol @@ -4,13 +4,14 @@ pragma solidity ^0.8.20; import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; import "@hamza-escrow/security/HasSecurityContext.sol"; +import "@hamza-escrow/IPurchaseTracker.sol"; /** * @title PurchaseTracker * @notice A singleton contract that records purchase data. * */ -contract PurchaseTracker is HasSecurityContext { +contract PurchaseTracker is HasSecurityContext, IPurchaseTracker { using SafeERC20 for IERC20; // Mapping from buyer address to cumulative purchase count and total purchase amount. @@ -20,16 +21,10 @@ contract PurchaseTracker is HasSecurityContext { // mapping for sellers mapping(address => uint256) public totalSalesCount; mapping(address => uint256) public totalSalesAmount; - - // mapping rewards distributed - mapping(address => uint256) public rewardsDistributed; // Store details about each purchase (keyed by the unique payment ID). mapping(bytes32 => Purchase) public purchases; - //Comunity Vault address - address public communityVault; - // loot token IERC20 public lootToken; @@ -50,8 +45,7 @@ contract PurchaseTracker is HasSecurityContext { _; } - constructor(ISecurityContext securityContext, address _communityVault, address _lootToken) { - communityVault = _communityVault; + constructor(ISecurityContext securityContext, address _lootToken) { lootToken = IERC20(_lootToken); _setSecurityContext(securityContext); } @@ -96,23 +90,20 @@ contract PurchaseTracker is HasSecurityContext { emit PurchaseRecorded(paymentId, buyer, amount); } - - // distrubte reward from communtiy vault - function distributeReward(address recipient) external { - // for every purchase or sale made by the recipient, distribute 1 loot token - uint256 totalPurchase = totalPurchaseCount[recipient]; - uint256 totalSales = totalSalesCount[recipient]; - uint256 rewardsDist = rewardsDistributed[recipient]; - - uint256 totalRewards = totalPurchase + totalSales - rewardsDist; - require(totalRewards > 0, "PurchaseTracker: No rewards to distribute"); + function getPurchaseCount(address recipient) external view returns (uint256) { + return totalPurchaseCount[recipient]; + } - // transfer loot token from community vault to recipient - lootToken.safeTransferFrom(communityVault, recipient, totalRewards); + function getPurchaseAmount(address recipient) external view returns (uint256) { + return totalPurchaseAmount[recipient]; + } - rewardsDistributed[recipient] += totalRewards; + function getSalesCount(address recipient) external view returns (uint256) { + return totalSalesCount[recipient]; } - + function getSalesAmount(address recipient) external view returns (uint256) { + return totalSalesAmount[recipient]; + } } diff --git a/test/CommunityVault.t.sol b/test/CommunityVault.t.sol new file mode 100644 index 0000000..1cd6036 --- /dev/null +++ b/test/CommunityVault.t.sol @@ -0,0 +1,556 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.17; + +import "forge-std/Test.sol"; +import "./DeploymentSetup.t.sol"; +import "../src/PurchaseTracker.sol"; +import "@hamza-escrow/PaymentEscrow.sol" as EscrowLib; +import "@hamza-escrow/security/Roles.sol"; +import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import "@hamza-escrow/ISystemSettings.sol"; +import "@hamza-escrow/TestToken.sol"; + +/** + * @notice This test suite tests that the CommunityVault behaves as expected. + */ +contract TestCommunityVault is DeploymentSetup { + CommunityVault internal vault; + PurchaseTracker internal tracker; + TestToken internal testToken; + + address internal depositor; + address internal recipient; + + address internal depositor1; + address internal recipient1; + + address internal depositor2; + address internal recipient2; + + address internal depositor3; + address internal recipient3; + + uint256 internal constant INITIAL_USER_BALANCE = 100_000_000_000 ether; + IERC20 internal loot; + + function setUp() public virtual override { + super.setUp(); + + // Cast addresses into actual contract instances + testToken = new TestToken("ABC", "123"); + tracker = PurchaseTracker(purchaseTracker); + vault = CommunityVault(communityVault); + loot = IERC20(lootToken); + + // Define test addresses + depositor1 = makeAddr("depositor1"); + recipient1 = makeAddr("recipient1"); + + depositor2 = makeAddr("depositor2"); + recipient2 = makeAddr("recipient2"); + + depositor3 = makeAddr("depositor3"); + recipient3 = makeAddr("recipient3"); + + depositor = depositor1; + recipient = recipient1; + + // Fund the depositor with both ETH and ERC20 tokens + vm.deal(depositor1, INITIAL_USER_BALANCE); + vm.deal(depositor2, INITIAL_USER_BALANCE); + vm.deal(depositor3, INITIAL_USER_BALANCE); + deal(address(loot), depositor1, INITIAL_USER_BALANCE); + deal(address(loot), depositor2, INITIAL_USER_BALANCE); + deal(address(loot), depositor3, INITIAL_USER_BALANCE); + deal(address(testToken), depositor1, INITIAL_USER_BALANCE); + deal(address(testToken), depositor2, INITIAL_USER_BALANCE); + deal(address(testToken), depositor3, INITIAL_USER_BALANCE); + } + + // Test that getBalance gets the right balance + function testGetCorrectBalance() public { + uint256 depositAmount1 = 1000; + uint256 depositAmount2 = 1200; + uint256 depositAmount3 = 10040; + uint256 withdrawAmount1 = 100; + uint256 withdrawAmount2 = 120; + uint256 withdrawAmount3 = 1400; + + deposit(depositor1, IERC20(loot), depositAmount1); + assertEq(vault.getBalance(address(loot)), depositAmount1); + + deposit(depositor2, IERC20(loot), depositAmount2); + assertEq(vault.getBalance(address(loot)), depositAmount1 + depositAmount2); + + deposit(depositor3, IERC20(testToken), depositAmount3); + assertEq(vault.getBalance(address(loot)), depositAmount1 + depositAmount2); + assertEq(vault.getBalance(address(testToken)), depositAmount3); + + vm.prank(admin); + vault.withdraw(address(loot), recipient1, depositAmount1); + assertEq(vault.getBalance(address(loot)), depositAmount2); + assertEq(vault.getBalance(address(testToken)), depositAmount3); + + vm.prank(admin); + vault.withdraw(address(loot), recipient1, depositAmount2); + assertEq(vault.getBalance(address(loot)), 0); + assertEq(vault.getBalance(address(testToken)), depositAmount3); + + vm.prank(admin); + vault.withdraw(address(testToken), recipient1, depositAmount3); + assertEq(vault.getBalance(address(loot)), 0); + assertEq(vault.getBalance(address(testToken)), 0); + } + + // Test that deposit emits the Deposit event + function testDepositEmitsEvent() public { + uint256 depositAmount = 1000; + + vm.startPrank(depositor1); + loot.approve(address(vault), depositAmount); + vm.expectEmit(false, false, false, false); + emit CommunityVault.Deposit(address(loot), depositor1, depositAmount); + vault.deposit(address(loot), depositAmount); + vm.stopPrank(); + } + + // Test that deposit Incorrect ETH amount error + function testDepositIncorrectAmount() public { + vm.startPrank(depositor1); + vm.expectRevert("Incorrect ETH amount"); + vault.deposit(address(0), 1000); + vm.stopPrank(); + } + + // Test that deposit transfers token in the correct way + function testDepositTransfersTokens() public { + uint256 depositAmount1 = 1000; + uint256 depositAmount2 = 1100; + uint256 depositAmount3 = 1200; + + //initial balances + uint256 initialDepositor1Balance = loot.balanceOf(depositor1); + uint256 initialDepositor2Balance = loot.balanceOf(depositor2); + uint256 initialVaultBalance = loot.balanceOf(address(vault)); + + //first deposit + deposit(depositor1, IERC20(loot), depositAmount1); + + //check balances + assertEq(loot.balanceOf(depositor1), (initialDepositor1Balance - depositAmount1)); + assertEq(loot.balanceOf(address(vault)), (initialVaultBalance + depositAmount1)); + + //second deposit + deposit(depositor1, IERC20(loot), depositAmount2); + + //check balances + assertEq(loot.balanceOf(depositor1), (initialDepositor1Balance - depositAmount1 - depositAmount2)); + assertEq(loot.balanceOf(address(vault)), (initialVaultBalance + depositAmount1 + depositAmount2)); + + //third deposit + deposit(depositor2, IERC20(loot), depositAmount3); + + //check balances + assertEq(loot.balanceOf(depositor2), (initialDepositor2Balance - depositAmount3)); + assertEq(loot.balanceOf(address(vault)), (initialVaultBalance + depositAmount1 + depositAmount2 + depositAmount3)); + } + + // Test that deposit transfers ETH in the correct way + function testDepositTransfersEth() public { + uint256 depositAmount1 = 1000; + uint256 depositAmount2 = 1100; + uint256 depositAmount3 = 1200; + + //initial balances + uint256 initialDepositor1Balance = depositor1.balance; + uint256 initialDepositor2Balance = depositor2.balance; + uint256 initialVaultBalance = address(vault).balance; + + assertNotEq(initialDepositor1Balance, 0); + assertNotEq(initialDepositor2Balance, 0); + + //first deposit + deposit(depositor1, IERC20(address(0)), depositAmount1); + + //check balances + assertEq(depositor1.balance, (initialDepositor1Balance - depositAmount1)); + assertEq(address(vault).balance, (initialVaultBalance + depositAmount1)); + + //second deposit + deposit(depositor1, IERC20(address(0)), depositAmount2); + + //check balances + assertEq(depositor1.balance, (initialDepositor1Balance - depositAmount1 - depositAmount2)); + assertEq(address(vault).balance, (initialVaultBalance + depositAmount1 + depositAmount2)); + + //third deposit + deposit(depositor2, IERC20(address(0)), depositAmount3); + + //check balances + assertEq(depositor2.balance, (initialDepositor2Balance - depositAmount3)); + assertEq(address(vault).balance, (initialVaultBalance + depositAmount1 + depositAmount2 + depositAmount3)); //, 0.1e16); + assertEq(vault.getBalance(address(0)), address(vault).balance); + } + + // Test that deposit behaves correctly when balance too low + function testDepositOverLimit() public { + //TODO: testDepositOverLimit + } + + // Test that withdraw Insufficient Balance error + function testWithdrawInsufficientBalanceError() public { + uint256 depositAmount = 1000; + + deposit(depositor1, IERC20(loot), depositAmount); + + vm.startPrank(admin); + vm.expectRevert(); + vault.withdraw(address(loot), recipient1, depositAmount + 1); + vm.stopPrank(); + } + + // Test that withdraw transfers token correctly + function testWithdrawTransfersTokens() public { + //TODO: testWithdrawTransfersTokens + } + + // Test that withdraw emits Withdraw event + function testWithdrawEmitsEvent() public { + //TODO: testWithdrawEmitsEvent + } + + // Test that withdraw behaves correctly when balance too low + function testWithdrawOverLimit() public { + //TODO: testWithdrawOverLimit + } + + // Test that withdraw is only callable if authorized + function testWithdrawRestricted() public { + uint256 depositAmount = 100; + deposit(depositor1, IERC20(loot), depositAmount); + + vm.expectRevert(); + vault.withdraw(address(loot), depositor2, depositAmount); + } + + // Test that distribute is only callable if authorized + function testDistributeRestricted() public { + uint256 depositAmount = 100; + deposit(depositor1, IERC20(loot), depositAmount); + + address[] memory recipients = new address[](1); + uint256[] memory amounts = new uint256[](1); + recipients[0] = recipient; + amounts[0] = 50; + + vm.expectRevert(); + vault.distribute(address(loot), recipients, amounts); + } + + // Test that distribute "Mismatched arrays" error + function testDistributeMismatchedArraysError() public { + uint256 depositAmount = 100; + deposit(depositor1, IERC20(loot), depositAmount); + + address[] memory recipients = new address[](1); + uint256[] memory amounts = new uint256[](2); + recipients[0] = recipient1; + amounts[0] = 50; + amounts[1] = 10; + + vm.startPrank(admin); + vm.expectRevert("Mismatched arrays"); + vault.distribute(address(loot), recipients, amounts); + vm.stopPrank(); + } + + // Test that distribute adjusts balances & distributes token correctly + //TODO: finish testDistributeRewards + function testDistributeRewards() public { + uint256 depositLootAmount1 = 1000; + uint256 depositLootAmount2 = 1100; + uint256 depositLootAmount3 = 1200; + uint256 depositLootTotal = depositLootAmount1+depositLootAmount2+depositLootAmount3; + uint256 depositEthAmount1 = 3000; + uint256 depositEthAmount2 = 2100; + uint256 depositEthAmount3 = 4200; + uint256 depositEthTotal = depositEthAmount1+depositEthAmount2+depositEthAmount3; + + //deposit loot + deposit(depositor1, IERC20(loot), depositLootAmount1); + deposit(depositor1, IERC20(loot), depositLootAmount2); + deposit(depositor2, IERC20(loot), depositLootAmount3); + + //deposit eth + deposit(depositor1, IERC20(address(0)), depositLootAmount1); + deposit(depositor1, IERC20(address(0)), depositLootAmount2); + deposit(depositor2, IERC20(address(0)), depositLootAmount3); + + //prepare to distribute + address[] memory recipients = new address[](3); + recipients[0] = recipient1; + recipients[1] = recipient2; + recipients[2] = recipient3; + + //prepare loot amounts to distribute + uint256 distributeLootAmount1 = (depositLootTotal/3)-1; + uint256 distributeLootAmount2 = (depositLootTotal/3)-2; + uint256 distributeLootAmount3 = (depositLootTotal/3)-3; + + //prepare eth amounts to distribute + uint256 distributeEthAmount1 = (depositEthTotal/10); + uint256 distributeEthAmount2 = (depositEthTotal/10); + uint256 distributeEthAmount3 = (depositEthTotal/10); + + uint256[] memory amounts = new uint256[](3); + amounts[0] = distributeLootAmount1; + amounts[1] = distributeLootAmount2; + amounts[2] = distributeLootAmount3; + + //distribute loot + vm.prank(admin); + vault.distribute(address(loot), recipients, amounts); + + amounts[0] = distributeEthAmount1; + amounts[1] = distributeEthAmount2; + amounts[2] = distributeEthAmount3; + + //distribute eth + vm.prank(admin); + vault.distribute(address(0), recipients, amounts); + } + + // Test that distribute handles insufficient balances correctly + function testDistributeInsufficientBalance() public { + //TODO: testDistributeInsufficientBalance + } + + // Test that distribute emits Distribute event + function testDistributeEmitsEvent() public { + uint256 depositAmount = 100; + deposit(depositor1, IERC20(loot), depositAmount); + + address[] memory recipients = new address[](1); + uint256[] memory amounts = new uint256[](1); + recipients[0] = recipient1; + amounts[0] = 50; + + vm.startPrank(admin); + vm.expectEmit(false, false, false, false); + emit CommunityVault.Distribute(address(loot), recipients[0], amounts[0]); + vault.distribute(address(loot), recipients, amounts); + vm.stopPrank(); + } + + // Test that setGovernanceVault can be only called by admin + function testSetGovernanceVaultRestricted() public { + vm.expectRevert(); + vault.setGovernanceVault(govVault, address(loot)); + } + + // Test that setGovernanceVault sets the GovernanceVault + function testSetGovernanceVault() public { + GovernanceVault newGovVault = new GovernanceVault( + ISecurityContext(hatsCtx), + address(testToken), + GovernanceToken(address(govToken)), + 100 + ); + + vm.prank(admin); + vault.setGovernanceVault(address(newGovVault), address(testToken)); + + assertEq(address(vault.governanceVault()), address(newGovVault)); + assertEq(testToken.allowance(address(vault), address(vault.governanceVault())), type(uint256).max); + } + + // Test that setGovernanceVault validates address arguments + function testSetGovernanceVaultAddressZero() public { + vm.startPrank(admin); + + //invalid governance vault + vm.expectRevert("Invalid staking contract address"); + vault.setGovernanceVault(address(0), address(loot)); + + //invalid loot token + vm.expectRevert("Invalid loot token address"); + vault.setGovernanceVault(govVault, address(0)); + + vm.stopPrank(); + } + + // Test that setPurchaseTracker can be only called by admin + function testSetPurchaseTrackerRestricted() public { + vm.expectRevert(); + vault.setPurchaseTracker(purchaseTracker); + } + + // Test that setCommunityRewardsCalculator sets the CommunityRewardsCalculator + function testSetCommunityRewardsCalculator() public { + CommunityRewardsCalculator newCalc = new CommunityRewardsCalculator(); + + vm.prank(admin); + vault.setCommunityRewardsCalculator(ICommunityRewardsCalculator(newCalc)); + + assertEq(address(vault.rewardsCalculator()), address(newCalc)); + } + + // Test that setPurchaseTracker can be only called by admin + function testSetCommunityRewardsCalculatorRestricted() public { + CommunityRewardsCalculator calc = new CommunityRewardsCalculator(); + vm.expectRevert(); + vault.setCommunityRewardsCalculator(calc); + } + + // Test that setPurchaseTracker sets the PurchaseTracker + function testSetPurchaseTracker() public { + PurchaseTracker newTracker = new PurchaseTracker( + ISecurityContext(hatsCtx), address(testToken) + ); + + vm.prank(admin); + vault.setPurchaseTracker(address(newTracker)); + + assertEq(address(vault.purchaseTracker()), address(newTracker)); + } + + // Test that setPurchaseTracker validates address arguments + function testSetPurchaseTrackerAddressZero() public { + vm.startPrank(admin); + + //invalid purchase tracker + vm.expectRevert("Invalid purchase tracker address"); + vault.setPurchaseTracker(address(0)); + + vm.stopPrank(); + } + + // Test that rewards are distributed through the PurchaseTracker and CommunityRewardsCalculator to buyers + function testDistributeRewardsForBuyer() public { + bytes32 paymentId = keccak256("payment-reward-test-1"); + uint256 payAmount = 500; + + address payer = depositor1; + address seller = recipient1; + + PaymentEscrow payEscrow = PaymentEscrow(payable(escrow)); + + //make sure there's enough in the vault to distribute + deposit(depositor2, loot, 100_000); + + // Buyer makes a purchase + vm.startPrank(payer); + loot.approve(address(payEscrow), payAmount); + + PaymentInput memory input = PaymentInput({ + id: paymentId, + payer: payer, + receiver: seller, + currency: address(loot), + amount: payAmount + }); + payEscrow.placePayment(input); + vm.stopPrank(); + + // Buyer and seller release escrow + vm.prank(depositor1); + payEscrow.releaseEscrow(paymentId); + + if (!autoRelease) { + vm.prank(seller); + payEscrow.releaseEscrow(paymentId); + } + + // Validate purchase tracking + assertEq(tracker.totalPurchaseCount(payer), 1, "Incorrect purchase count"); + assertEq(tracker.totalSalesCount(seller), 1, "Incorrect sales count"); + + // Check initial reward balance + uint256 initialBuyerBalance = loot.balanceOf(payer); + uint256 rewardsToDistribute = tracker.totalPurchaseCount(payer); + + // Distribute reward + address[] memory recipients = new address[](1); + recipients[0] = payer; + + vm.prank(admin); + CommunityVault(communityVault).distributeRewards(address(loot), recipients); + + // Verify rewards were distributed + assertEq(loot.balanceOf(payer), initialBuyerBalance + rewardsToDistribute, "Incorrect reward distribution"); + assertEq(vault.rewardsDistributed(address(loot), payer), rewardsToDistribute, "Incorrect rewards tracked"); + } + + // Test that rewards are distributed through the PurchaseTracker and CommunityRewardsCalculator to sellers + function testDistributeRewardsForSeller() public { + bytes32 paymentId = keccak256("payment-reward-test-2"); + uint256 payAmount = 750_000_000_000_000; + + address payer = depositor1; + address seller = recipient1; + + PaymentEscrow payEscrow = PaymentEscrow(payable(escrow)); + + //make sure there's enough in the vault to distribute + deposit(depositor2, loot, 100_000); + + // Buyer makes a purchase + vm.startPrank(payer); + loot.approve(address(payEscrow), payAmount); + + PaymentInput memory input = PaymentInput({ + id: paymentId, + payer: payer, + receiver: seller, + currency: address(loot), + amount: payAmount + }); + payEscrow.placePayment(input); + vm.stopPrank(); + + // Buyer and seller release escrow + vm.prank(payer); + payEscrow.releaseEscrow(paymentId); + + if (!autoRelease) { + vm.prank(seller); + payEscrow.releaseEscrow(paymentId); + } + + // Validate purchase tracking + assertEq(tracker.totalSalesCount(seller), 1, "Incorrect sales count"); + + // Check initial reward balance + uint256 initialSellerBalance = loot.balanceOf(seller); + uint256 rewardsToDistribute = tracker.totalSalesCount(seller); + + assertEq(rewardsToDistribute, 1); + return; + + // Distribute rewards + address[] memory recipients = new address[](1); + recipients[0] = seller; + + vm.prank(admin); + CommunityVault(communityVault).distributeRewards(address(loot), recipients); + + // Verify rewards were distributed + assertEq(loot.balanceOf(seller), initialSellerBalance + rewardsToDistribute, "Incorrect reward distribution"); + assertEq(vault.rewardsDistributed(address(loot), seller), rewardsToDistribute, "Incorrect rewards tracked"); + } + + function deposit(address _depositor, IERC20 token, uint256 amount) private { + vm.startPrank(_depositor); + if (address(token) != address(0)) { + token.approve(address(vault), amount); + vault.deposit(address(token), amount); + } + else { + address(vault).call{value: amount}( + abi.encodeWithSignature("deposit(address,uint256)", address(token), amount) + ); + } + + vm.stopPrank(); + } +} diff --git a/test/PurchaseTracker.t.sol b/test/PurchaseTracker.t.sol index 55babe2..ca6340c 100644 --- a/test/PurchaseTracker.t.sol +++ b/test/PurchaseTracker.t.sol @@ -325,102 +325,9 @@ contract TestPaymentAndTracker is DeploymentSetup { assertEq(tracker.totalSalesAmount(seller), netAmount, "Seller sales total mismatch"); } - function testDistributeRewardsForBuyer() public { - bytes32 paymentId = keccak256("payment-reward-test-1"); - uint256 payAmount = 500; + //TODO: TEST: test that PurchaseRecorded event is emitted + //TODO: TEST: test 'PurchaseTracker: Purchase already recorded' error - // Buyer makes a purchase - vm.startPrank(payer); - loot.approve(address(payEscrow), payAmount); - - PaymentInput memory input = PaymentInput({ - id: paymentId, - payer: payer, - receiver: seller, - currency: address(loot), - amount: payAmount - }); - payEscrow.placePayment(input); - vm.stopPrank(); - - // Buyer and seller release escrow - vm.prank(payer); - payEscrow.releaseEscrow(paymentId); - - if (!autoRelease) { - vm.prank(seller); - payEscrow.releaseEscrow(paymentId); - } - - // Validate purchase tracking - assertEq(tracker.totalPurchaseCount(payer), 1, "Incorrect purchase count"); - assertEq(tracker.totalSalesCount(seller), 1, "Incorrect sales count"); - - // Check initial reward balance - uint256 initialBuyerBalance = loot.balanceOf(payer); - uint256 rewardsToDistribute = tracker.totalPurchaseCount(payer); - - // Distribute reward - vm.prank(payer); - tracker.distributeReward(payer); - - // Verify rewards were distributed - assertEq(loot.balanceOf(payer), initialBuyerBalance + rewardsToDistribute, "Incorrect reward distribution"); - assertEq(tracker.rewardsDistributed(payer), rewardsToDistribute, "Incorrect rewards tracked"); - } - - function testDistributeRewardsForSeller() public { - bytes32 paymentId = keccak256("payment-reward-test-2"); - uint256 payAmount = 750; - - // Buyer makes a purchase - vm.startPrank(payer); - loot.approve(address(payEscrow), payAmount); - - PaymentInput memory input = PaymentInput({ - id: paymentId, - payer: payer, - receiver: seller, - currency: address(loot), - amount: payAmount - }); - payEscrow.placePayment(input); - vm.stopPrank(); - - // Buyer and seller release escrow - vm.prank(payer); - payEscrow.releaseEscrow(paymentId); - - if (!autoRelease) { - vm.prank(seller); - payEscrow.releaseEscrow(paymentId); - } - - // Validate purchase tracking - assertEq(tracker.totalSalesCount(seller), 1, "Incorrect sales count"); - - // Check initial reward balance - uint256 initialSellerBalance = loot.balanceOf(seller); - uint256 rewardsToDistribute = tracker.totalSalesCount(seller); - - // Distribute reward - vm.prank(seller); - tracker.distributeReward(seller); - - // Verify rewards were distributed - assertEq(loot.balanceOf(seller), initialSellerBalance + rewardsToDistribute, "Incorrect reward distribution"); - assertEq(tracker.rewardsDistributed(seller), rewardsToDistribute, "Incorrect rewards tracked"); - } - - function testDistributeRewardsFailsIfNoPurchasesOrSales() public { - // Check initial reward distribution - assertEq(tracker.rewardsDistributed(arbiter), 0, "Arbiter should have no rewards"); - - // Expect revert due to no rewards available - vm.expectRevert("PurchaseTracker: No rewards to distribute"); - vm.prank(arbiter); - tracker.distributeReward(arbiter); - } function payEscrowSettingsFee() internal view returns (uint256) { return systemSettings1.feeBps();