diff --git a/src/CommunityRewardsCalculator.sol b/src/CommunityRewardsCalculator.sol index 0b907a8..8f3e69f 100644 --- a/src/CommunityRewardsCalculator.sol +++ b/src/CommunityRewardsCalculator.sol @@ -12,16 +12,16 @@ import "./ICommunityRewardsCalculator.sol"; contract CommunityRewardsCalculator is ICommunityRewardsCalculator { function getRewardsToDistribute( - address token, + address /*token*/, address[] calldata recipients, IPurchaseTracker purchaseTracker - ) external returns (uint256[] memory) { + ) external view returns (uint256[] memory) { uint256[] memory amounts = new uint256[](recipients.length); // for every purchase or sale made by the recipient, distribute 1 loot token for (uint i=0; i uint256) public tokenBalances; - // Keeps a count of already distributed rewards mapping(address => mapping(address => uint256)) public rewardsDistributed; @@ -58,8 +55,6 @@ contract CommunityVault is HasSecurityContext { IERC20(token).safeTransferFrom(msg.sender, address(this), amount); } - tokenBalances[token] += amount; - emit Deposit(token, msg.sender, amount); } @@ -70,7 +65,7 @@ contract CommunityVault is HasSecurityContext { * @param amount The amount to withdraw */ function withdraw(address token, address to, uint256 amount) external onlyRole(Roles.SYSTEM_ROLE) { - require(tokenBalances[token] >= amount, "Insufficient balance"); + require(this.getBalance(token) >= amount, "Insufficient balance"); if (token == address(0)) { // ETH withdrawal @@ -81,8 +76,6 @@ contract CommunityVault is HasSecurityContext { IERC20(token).safeTransfer(to, amount); } - tokenBalances[token] -= amount; - emit Withdraw(token, to, amount); } @@ -158,7 +151,8 @@ contract CommunityVault is HasSecurityContext { * @param token The address of the token */ function getBalance(address token) external view returns (uint256) { - return tokenBalances[token]; + if (token == address(0)) return (address(this)).balance; + return IERC20(token).balanceOf(address(this)); } function _distributeRewards(address token, address[] memory recipients) internal { @@ -181,7 +175,7 @@ contract CommunityVault is HasSecurityContext { require(recipients.length == amounts.length, "Mismatched arrays"); for (uint256 i = 0; i < recipients.length; i++) { - require(tokenBalances[token] >= amounts[i], "Insufficient balance"); + require(this.getBalance(token) >= amounts[i], "Insufficient balance"); if (token == address(0)) { // ETH distribution @@ -192,9 +186,6 @@ contract CommunityVault is HasSecurityContext { IERC20(token).safeTransfer(recipients[i], amounts[i]); } - // decrement balance - tokenBalances[token] -= amounts[i]; - // record the distribution rewardsDistributed[token][recipients[i]] += amounts[i]; diff --git a/test/CommunityVault.t.sol b/test/CommunityVault.t.sol index 1cd6036..119fba0 100644 --- a/test/CommunityVault.t.sol +++ b/test/CommunityVault.t.sol @@ -67,6 +67,7 @@ contract TestCommunityVault is DeploymentSetup { deal(address(testToken), depositor3, INITIAL_USER_BALANCE); } + // Test that getBalance gets the right balance function testGetCorrectBalance() public { uint256 depositAmount1 = 1000; @@ -75,31 +76,33 @@ contract TestCommunityVault is DeploymentSetup { uint256 withdrawAmount1 = 100; uint256 withdrawAmount2 = 120; uint256 withdrawAmount3 = 1400; + uint256 initialLootVaultBalance = IERC20(loot).balanceOf(address(vault)); + uint256 initialTestVaultBalance = IERC20(testToken).balanceOf(address(vault)); deposit(depositor1, IERC20(loot), depositAmount1); - assertEq(vault.getBalance(address(loot)), depositAmount1); + assertEq(vault.getBalance(address(loot)), initialLootVaultBalance + depositAmount1); deposit(depositor2, IERC20(loot), depositAmount2); - assertEq(vault.getBalance(address(loot)), depositAmount1 + depositAmount2); + assertEq(vault.getBalance(address(loot)), initialLootVaultBalance + depositAmount1 + depositAmount2); deposit(depositor3, IERC20(testToken), depositAmount3); - assertEq(vault.getBalance(address(loot)), depositAmount1 + depositAmount2); - assertEq(vault.getBalance(address(testToken)), depositAmount3); + assertEq(vault.getBalance(address(loot)), initialLootVaultBalance + depositAmount1 + depositAmount2); + assertEq(vault.getBalance(address(testToken)), initialTestVaultBalance + depositAmount3); vm.prank(admin); - vault.withdraw(address(loot), recipient1, depositAmount1); - assertEq(vault.getBalance(address(loot)), depositAmount2); - assertEq(vault.getBalance(address(testToken)), depositAmount3); + vault.withdraw(address(loot), recipient1, withdrawAmount1); + assertEq(vault.getBalance(address(loot)), initialLootVaultBalance + depositAmount1 + depositAmount2 - withdrawAmount1); + assertEq(vault.getBalance(address(testToken)), initialTestVaultBalance + depositAmount3); vm.prank(admin); - vault.withdraw(address(loot), recipient1, depositAmount2); - assertEq(vault.getBalance(address(loot)), 0); - assertEq(vault.getBalance(address(testToken)), depositAmount3); + vault.withdraw(address(loot), recipient1, withdrawAmount2); + assertEq(vault.getBalance(address(loot)), initialLootVaultBalance + depositAmount1 + depositAmount2 - withdrawAmount1 - withdrawAmount2); + assertEq(vault.getBalance(address(testToken)), initialTestVaultBalance + depositAmount3); vm.prank(admin); - vault.withdraw(address(testToken), recipient1, depositAmount3); - assertEq(vault.getBalance(address(loot)), 0); - assertEq(vault.getBalance(address(testToken)), 0); + vault.withdraw(address(testToken), recipient1, withdrawAmount3); + assertEq(vault.getBalance(address(loot)), initialLootVaultBalance + depositAmount1 + depositAmount2 - withdrawAmount1 - withdrawAmount2); + assertEq(vault.getBalance(address(testToken)), initialTestVaultBalance + depositAmount3 - withdrawAmount3); } // Test that deposit emits the Deposit event @@ -188,7 +191,7 @@ contract TestCommunityVault is DeploymentSetup { //check balances assertEq(depositor2.balance, (initialDepositor2Balance - depositAmount3)); - assertEq(address(vault).balance, (initialVaultBalance + depositAmount1 + depositAmount2 + depositAmount3)); //, 0.1e16); + assertEq(address(vault).balance, (initialVaultBalance + depositAmount1 + depositAmount2 + depositAmount3)); assertEq(vault.getBalance(address(0)), address(vault).balance); } @@ -200,12 +203,13 @@ contract TestCommunityVault is DeploymentSetup { // Test that withdraw Insufficient Balance error function testWithdrawInsufficientBalanceError() public { uint256 depositAmount = 1000; + uint256 initialLootVaultBalance = IERC20(loot).balanceOf(address(vault)); deposit(depositor1, IERC20(loot), depositAmount); vm.startPrank(admin); vm.expectRevert(); - vault.withdraw(address(loot), recipient1, depositAmount + 1); + vault.withdraw(address(loot), recipient1, initialLootVaultBalance + depositAmount + 1); vm.stopPrank(); } @@ -319,7 +323,7 @@ contract TestCommunityVault is DeploymentSetup { vm.prank(admin); vault.distribute(address(0), recipients, amounts); } - + // Test that distribute handles insufficient balances correctly function testDistributeInsufficientBalance() public { //TODO: testDistributeInsufficientBalance @@ -525,7 +529,6 @@ contract TestCommunityVault is DeploymentSetup { uint256 rewardsToDistribute = tracker.totalSalesCount(seller); assertEq(rewardsToDistribute, 1); - return; // Distribute rewards address[] memory recipients = new address[](1); @@ -539,6 +542,7 @@ contract TestCommunityVault is DeploymentSetup { assertEq(vault.rewardsDistributed(address(loot), seller), rewardsToDistribute, "Incorrect rewards tracked"); } + function deposit(address _depositor, IERC20 token, uint256 amount) private { vm.startPrank(_depositor); if (address(token) != address(0)) { @@ -546,9 +550,10 @@ contract TestCommunityVault is DeploymentSetup { vault.deposit(address(token), amount); } else { - address(vault).call{value: amount}( + (bool success,) = address(vault).call{value: amount}( abi.encodeWithSignature("deposit(address,uint256)", address(token), amount) ); + assertTrue(success); } vm.stopPrank(); diff --git a/test/PurchaseTracker.t.sol b/test/PurchaseTracker.t.sol index f250e2e..de9f281 100644 --- a/test/PurchaseTracker.t.sol +++ b/test/PurchaseTracker.t.sol @@ -348,7 +348,7 @@ contract TestPaymentAndTracker is DeploymentSetup { uint256 payAmount2 = 2000; // Place a payment using loot token - EscrowLib.Payment memory payment1 = placePayment( + placePayment( payEscrow, paymentId1, payer, seller, address(loot), payAmount1 );