diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index d1ab936..ac61140 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -24,7 +24,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v3 with: - python-version: "3.9" + python-version: "3.12" - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 4bf46dc..c8e15c7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10"] + python-version: ["3.10", "3.12"] steps: - uses: actions/checkout@v3 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ff73444..1773d74 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ exclude: "^docs/conf.py" repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.1.0 + rev: v4.4.0 hooks: - id: trailing-whitespace - id: check-added-large-files @@ -23,19 +23,19 @@ repos: - id: isort - repo: https://github.com/psf/black - rev: stable + rev: 23.3.0 hooks: - id: black language_version: python3 - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + rev: 6.0.0 hooks: - id: flake8 ## You can add flake8 plugins via `additional_dependencies`: # additional_dependencies: [flake8-bugbear] - repo: https://github.com/zricethezav/gitleaks - rev: v8.12.0 + rev: v8.16.3 hooks: - id: gitleaks-docker diff --git a/src/ethproto/contracts.py b/src/ethproto/contracts.py index 6c6c616..baf5eef 100644 --- a/src/ethproto/contracts.py +++ b/src/ethproto/contracts.py @@ -4,6 +4,7 @@ from decimal import Decimal from functools import wraps +from environs import Env from m9g import Model from m9g.fields import DictField, IntField, ListField, StringField, TupleField @@ -13,11 +14,24 @@ __copyright__ = "Guillermo M. Narvaja" __license__ = "MIT" +env = Env() + +USE_CUSTOM_ERRORS = env.bool("USE_CUSTOM_ERRORS", False) + class RevertError(Exception): pass +class RevertCustomError(RevertError): + def __init__(self, error, *args): + self.error = error + self.args = args + + def __str__(self): + return f"{self.error}({', '.join(map(str, self.args))})" + + class WadField(IntField): FIELD_TYPE = Wad @@ -144,9 +158,9 @@ def track(self, contract): def _on_end(self): while self.modified_contracts: contract = self.modified_contracts.pop() - assert contract.serialize("pydict") == self.serialized_contracts[ - contract.contract_id - ], f"Contract {contract.contract_id} modified in view" + assert ( + contract.serialize("pydict") == self.serialized_contracts[contract.contract_id] + ), f"Contract {contract.contract_id} modified in view" del self.serialized_contracts[contract.contract_id] def archive(self): @@ -198,10 +212,11 @@ def inner(self, *args, **kwargs): if self.has_role(role, self.running_as): break else: - raise RevertError(f"AccessControl: account {self.running_as} is missing role {role}") + self._error("AccessControlUnauthorizedAccount", self.running_as, role) return method(self, *args, **kwargs) return inner + return decorator @@ -229,10 +244,14 @@ class Contract(Model): def __init__(self, contract_id=None, **kwargs): if contract_id is None: contract_id = f"{self.__class__.__name__}-{id(self)}" + self.use_custom_errors = kwargs.pop("use_custom_errors", USE_CUSTOM_ERRORS) super().__init__(contract_id=contract_id, **kwargs) self._versions = [] self.manager.add_contract(self.contract_id, self) + def _error(self, error_class, *args) -> RevertError: + return RevertCustomError(error_class, *args) + @contextmanager def as_(self, user): "Dummy as method to do the same with the wrapper" @@ -272,11 +291,7 @@ def pop_version(self, version_name=None): class AccessControlContract(Contract): owner = AddressField(default="owner") - roles = DictField( - StringField(), - TupleField((ListField(AddressField()), StringField())), - default={} - ) + roles = DictField(StringField(), TupleField((ListField(AddressField()), StringField())), default={}) set_attr_roles = {} @@ -298,6 +313,14 @@ def __init__(self, **kwargs): self._running_as = self.owner self.roles[""] = ([self.owner], "") # Add owner as default_admin + def _error(self, error_class, *args) -> RevertError: + if error_class == "AccessControlUnauthorizedAccount": + if self.use_custom_errors: + return RevertCustomError(error_class, args[0], args[1]) + else: + return RevertError(f"AccessControl: account {args[0]} is missing role {args[1]}") + return super()._error(error_class, *args) + @contextmanager def _disable_role_validation(self): self._role_validation_disabled = True @@ -320,8 +343,10 @@ def grant_role(self, role, user): members, admin_role = self.roles[role] else: members, admin_role = [], "" - require(self.has_role(admin_role, self._running_as), - f"AccessControl: AccessControl: account {self._running_as} is missing role '{admin_role}'") + require( + self.has_role(admin_role, self._running_as), + self._error("AccessControlUnauthorizedAccount", self._running_as, admin_role), + ) if user not in members: members.append(user) @@ -337,13 +362,16 @@ def _validate_setattr(self, attr_name, value): if attr_name in self.set_attr_roles: require( self.has_role(self.set_attr_roles[attr_name], self._running_as), - f"AccessControl: AccessControl: account {self._running_as} is missing role " - f"'{self.set_attr_roles[attr_name]}'" + self._error( + "AccessControlUnauthorizedAccount", self._running_as, self.set_attr_roles[attr_name] + ), ) def require(condition, message=None): if not condition: + if isinstance(message, RevertError): + raise message raise RevertError(message or "required condition not met") @@ -354,14 +382,43 @@ class ERC20Token(AccessControlContract): symbol = StringField(default="") decimals = IntField(default=18) balances = DictField(AddressField(), WadField(), default={}) - allowances = DictField( - TupleField((AddressField(), AddressField())), - WadField(), - default={} - ) + allowances = DictField(TupleField((AddressField(), AddressField())), WadField(), default={}) _total_supply = WadField(default=ZERO) + _arg_count_by_error = { + "ERC20InsufficientBalance": 3, + "ERC20InvalidSender": 1, + "ERC20InvalidReceiver": 1, + "ERC20InsufficientAllowance": 3, + "ERC20InvalidApprover": 1, + "ERC20InvalidSpender": 1, + } + + _message_by_error = { + "ERC20InsufficientBalance": "ERC20: transfer amount exceeds balance", + "ERC20InvalidSender": "ERC20: transfer from the zero address", + "ERC20InvalidReceiver": "ERC20: transfer to the zero address", + "ERC20InsufficientAllowance": "ERC20: insufficient allowance", + "ERC20InvalidApprover": "ERC20: approve from the zero address", + "ERC20InvalidSpender": "ERC20: approve to the zero address", + } + + def _error(self, error_class, *args) -> RevertError: + if self.use_custom_errors: + arg_count = self._arg_count_by_error.get(error_class, None) + if arg_count == 1: + return RevertCustomError( + error_class, args[0] if args else "0x0000000000000000000000000000000000000000" + ) + elif arg_count is not None: + return RevertCustomError(error_class, *args[:arg_count]) + else: + message = self._message_by_error.get(error_class, None) + if message is not None: + return RevertError(message) + return super()._error(error_class, *args) + def __init__(self, **kwargs): if "initial_supply" in kwargs: initial_supply = kwargs.pop("initial_supply") @@ -410,7 +467,7 @@ def transfer(self, sender, recipient, amount): def _transfer(self, sender, recipient, amount): sender, recipient = self._parse_accounts(sender, recipient) if self.balance_of(sender) < amount: - raise RevertError("ERC20: transfer amount exceeds balance") + raise self._error("ERC20InsufficientBalance", sender, self.balance_of(sender), amount) elif self.balances[sender] == amount: del self.balances[sender] else: @@ -425,8 +482,8 @@ def allowance(self, owner, spender): def _approve(self, owner, spender, amount): owner, spender = self._parse_accounts(owner, spender) - require(owner is not None, "ERC20: approve from the zero address") - require(spender is not None, "ERC20: approve to the zero address") + require(owner is not None, self._error("ERC20InvalidApprover")) + require(spender is not None, self._error("ERC20InvalidSpender", spender)) if amount == self.ZERO: try: del self.allowances[(owner, spender)] @@ -447,7 +504,7 @@ def increase_allowance(self, sender, spender, amount): def decrease_allowance(self, sender, spender, amount): sender, spender = self._parse_accounts(sender, spender) allowance = self.allowances.get((sender, spender), self.ZERO) - require(allowance >= amount, "ERC20: decreased allowance below zero") + require(allowance >= amount, self._error("ERC20InsufficientAllowance", spender, allowance, amount)) self._approve(sender, spender, allowance - amount) @external @@ -455,7 +512,7 @@ def transfer_from(self, spender, sender, recipient, amount): spender, sender, recipient = self._parse_accounts(spender, sender, recipient) allowance = self.allowances.get((sender, spender), self.ZERO) if allowance < amount: - raise RevertError("ERC20: transfer amount exceeds allowance") + raise self._error("ERC20InsufficientAllowance", spender, allowance, amount) self._transfer(sender, recipient, amount) self._approve(sender, spender, allowance - amount) return True @@ -464,7 +521,7 @@ def total_supply(self): return self._total_supply -class ERC721Token(AccessControlContract): # NFT +class ERC721Token(AccessControlContract): # NFT ZERO = Wad(0) name = StringField() @@ -477,13 +534,49 @@ class ERC721Token(AccessControlContract): # NFT _token_count = IntField(default=0) + _arg_count_by_error = { + "ERC721InvalidOwner": 1, + "ERC721NonexistentToken": 1, + "ERC721IncorrectOwner": 3, + "ERC721InvalidSender": 1, + "ERC721InvalidReceiver": 1, + "ERC721InsufficientApproval": 2, + } + + _message_by_error = { + "ERC721InvalidOwner": "ERC721: address zero is not a valid owner", + "ERC721NonexistentToken": "ERC721: invalid token ID", + "ERC721IncorrectOwner": "ERC721: transfer from incorrect owner", + "ERC721InvalidSender": "ERC721: transfer from incorrect owner", + "ERC721InvalidReceiver": "ERC721: transfer to the zero address", + "ERC721InsufficientApproval": "ERC721: caller is not token owner nor approved", + } + + def _error(self, error_class, *args) -> RevertError: + if self.use_custom_errors: + arg_count = self._arg_count_by_error.get(error_class, None) + if arg_count == 1: + return RevertCustomError( + error_class, args[0] if args else "0x0000000000000000000000000000000000000000" + ) + elif arg_count is not None: + return RevertCustomError(error_class, *args[:arg_count]) + else: + message = self._message_by_error.get(error_class, None) + if message is not None: + return RevertError(message) + return super()._error(error_class, *args) + @external def mint(self, to, token_id): if token_id is None: self._token_count += 1 token_id = self._token_count if token_id in self.owners: - raise RevertError("ERC721: token already minted") + if self.use_custom_errors: + raise RevertError("ERC721: token already minted") + else: + raise self._error("ERC721InvalidSender") self.balances[to] = self.balances.get(to, 0) + 1 self.owners[token_id] = to @@ -503,7 +596,7 @@ def balance_of(self, address): @view def owner_of(self, token_id): if token_id not in self.owners: - raise RevertError("ERC721: invalid token ID") + raise self._error("ERC721NonexistentToken", token_id) return self.owners[token_id] # def token_uri @@ -536,23 +629,29 @@ def is_approved_for_all(self, owner, operator): @external def transfer_from(self, sender, from_, to, token_id): owner = self.owners[token_id] - if sender != owner and self.token_approvals.get(token_id, None) != sender and \ - sender not in self.operator_approvals.get(owner, []): - raise RevertError("ERC721: caller is not token owner or approved") + if ( + sender != owner + and self.token_approvals.get(token_id, None) != sender + and sender not in self.operator_approvals.get(owner, []) + ): + raise self._error("ERC721InsufficientApproval", sender, token_id) return self._transfer(from_, to, token_id) @external def safe_transfer_from(self, sender, from_, to, token_id): owner = self.owners[token_id] - if sender != owner and self.token_approvals.get(token_id, None) != sender and \ - sender not in self.operator_approvals.get(owner, []): - raise RevertError("ERC721: caller is not token owner or approved") + if ( + sender != owner + and self.token_approvals.get(token_id, None) != sender + and sender not in self.operator_approvals.get(owner, []) + ): + raise self._error("ERC721InsufficientApproval", sender, token_id) # TODO: if `to` is contract, call onERC721Received return self._transfer(from_, to, token_id) def _transfer(self, from_, to, token_id): if self.owners[token_id] != from_: - raise RevertError("ERC721: transfer of token that is not own:") + raise self._error("ERC721InvalidOwner", from_) if token_id in self.token_approvals: del self.token_approvals[token_id] self.balances[from_] -= 1 diff --git a/tests/test_contracts.py b/tests/test_contracts.py index e0dce95..132c60f 100644 --- a/tests/test_contracts.py +++ b/tests/test_contracts.py @@ -140,6 +140,8 @@ def _connected_contract_address(eth_wrapper_class, *args, **kwargs): ERC20TokenAlternatives = [ERC20Token] +ERC20TokenAlternatives.append(partial(ERC20Token, use_custom_errors=False)) +ERC20TokenAlternatives.append(partial(ERC20Token, use_custom_errors=True)) if "web3py" in TEST_ENV: ERC20TokenAlternatives.append(partial(TestCurrency, provider_key="w3")) @@ -215,8 +217,13 @@ def test_approve_flow(self, token_class): token.transfer_from("Spender", "owner", "Giacomo", _W(300)) assert token.allowance("owner", "Spender") == _W(0) - with pytest.raises(RevertError): - token.transfer_from("Spender", "owner", "Luca", _W(1)) + with pytest.raises(RevertError, match="allowance|ERC20InsufficientAllowance"): + try: + token.transfer_from("Spender", "owner", "Luca", _W(1)) + except RevertError as err: + if getattr(token, "use_custom_errors", False): + assert str(err).startswith("ERC20InsufficientAllowance(") + raise assert token.balance_of("Guillo") == _W(200) assert token.balance_of("owner") == _W(1500) @@ -225,6 +232,8 @@ def test_approve_flow(self, token_class): ERC721TokenAlternatives = [ERC721Token] +ERC721TokenAlternatives.append(partial(ERC721Token, use_custom_errors=False)) +ERC721TokenAlternatives.append(partial(ERC721Token, use_custom_errors=True)) if "web3py" in TEST_ENV: ERC721TokenAlternatives.append(partial(TestNFT, provider_key="w3")) @@ -242,7 +251,7 @@ def test_mint_burn(self, token_class): assert nft.owner_of(1235) == "CUST1" nft.burn("CUST1", 1235) assert nft.balance_of("CUST1") == 1 - with pytest.raises(RevertError, match="ERC721: invalid token ID"): + with pytest.raises(RevertError, match="ERC721: invalid token ID|ERC721NonexistentToken"): nft.owner_of(1235) nft.burn("CUST1", 1234) assert nft.balance_of("CUST1") == 0 @@ -297,5 +306,12 @@ def test_approve_for_all(self, token_class): assert nft.balance_of("CUST2") == 2 nft.set_approval_for_all("CUST1", "SPEND", False) - with pytest.raises(RevertError, match="ERC721: caller is not token owner or approved"): + with pytest.raises( + RevertError, + match=( + "ERC721InsufficientApproval" + if getattr(nft, "use_custom_errors", False) + else "ERC721: caller is not token owner" + ), + ): nft.transfer_from("SPEND", "CUST1", "CUST2", 1235) diff --git a/tox.ini b/tox.ini index 7a88287..8ff85e5 100644 --- a/tox.ini +++ b/tox.ini @@ -4,12 +4,12 @@ [tox] minversion = 3.15 -envlist = {py39,py310} +envlist = {py310,py312} [gh-actions] python = - 3.9: py39 3.10: py310 + 3.12: py312 [testenv] description = invoke pytest to run automated tests