Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: "3.9"
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10"]
python-version: ["3.10", "3.12"]

steps:
- uses: actions/checkout@v3
Expand Down
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ exclude: "^docs/conf.py"

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
rev: v4.4.0
hooks:
- id: trailing-whitespace
- id: check-added-large-files
Expand All @@ -23,19 +23,19 @@ repos:
- id: isort

- repo: https://github.com/psf/black
rev: stable
rev: 23.3.0
hooks:
- id: black
language_version: python3

- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
rev: 6.0.0
hooks:
- id: flake8
## You can add flake8 plugins via `additional_dependencies`:
# additional_dependencies: [flake8-bugbear]

- repo: https://github.com/zricethezav/gitleaks
rev: v8.12.0
rev: v8.16.3
hooks:
- id: gitleaks-docker
179 changes: 146 additions & 33 deletions src/ethproto/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from decimal import Decimal
from functools import wraps

from environs import Env
from m9g import Model
from m9g.fields import DictField, IntField, ListField, StringField, TupleField

Expand All @@ -13,11 +14,24 @@
__copyright__ = "Guillermo M. Narvaja"
__license__ = "MIT"

env = Env()

USE_CUSTOM_ERRORS = env.bool("USE_CUSTOM_ERRORS", False)


class RevertError(Exception):
pass


class RevertCustomError(RevertError):
def __init__(self, error, *args):
self.error = error
self.args = args

def __str__(self):
return f"{self.error}({', '.join(map(str, self.args))})"


class WadField(IntField):
FIELD_TYPE = Wad

Expand Down Expand Up @@ -144,9 +158,9 @@ def track(self, contract):
def _on_end(self):
while self.modified_contracts:
contract = self.modified_contracts.pop()
assert contract.serialize("pydict") == self.serialized_contracts[
contract.contract_id
], f"Contract {contract.contract_id} modified in view"
assert (
contract.serialize("pydict") == self.serialized_contracts[contract.contract_id]
), f"Contract {contract.contract_id} modified in view"
del self.serialized_contracts[contract.contract_id]

def archive(self):
Expand Down Expand Up @@ -198,10 +212,11 @@ def inner(self, *args, **kwargs):
if self.has_role(role, self.running_as):
break
else:
raise RevertError(f"AccessControl: account {self.running_as} is missing role {role}")
self._error("AccessControlUnauthorizedAccount", self.running_as, role)
return method(self, *args, **kwargs)

return inner

return decorator


Expand Down Expand Up @@ -229,10 +244,14 @@ class Contract(Model):
def __init__(self, contract_id=None, **kwargs):
if contract_id is None:
contract_id = f"{self.__class__.__name__}-{id(self)}"
self.use_custom_errors = kwargs.pop("use_custom_errors", USE_CUSTOM_ERRORS)
super().__init__(contract_id=contract_id, **kwargs)
self._versions = []
self.manager.add_contract(self.contract_id, self)

def _error(self, error_class, *args) -> RevertError:
return RevertCustomError(error_class, *args)

@contextmanager
def as_(self, user):
"Dummy as method to do the same with the wrapper"
Expand Down Expand Up @@ -272,11 +291,7 @@ def pop_version(self, version_name=None):

class AccessControlContract(Contract):
owner = AddressField(default="owner")
roles = DictField(
StringField(),
TupleField((ListField(AddressField()), StringField())),
default={}
)
roles = DictField(StringField(), TupleField((ListField(AddressField()), StringField())), default={})

set_attr_roles = {}

Expand All @@ -298,6 +313,14 @@ def __init__(self, **kwargs):
self._running_as = self.owner
self.roles[""] = ([self.owner], "") # Add owner as default_admin

def _error(self, error_class, *args) -> RevertError:
if error_class == "AccessControlUnauthorizedAccount":
if self.use_custom_errors:
return RevertCustomError(error_class, args[0], args[1])
else:
return RevertError(f"AccessControl: account {args[0]} is missing role {args[1]}")
return super()._error(error_class, *args)

@contextmanager
def _disable_role_validation(self):
self._role_validation_disabled = True
Expand All @@ -320,8 +343,10 @@ def grant_role(self, role, user):
members, admin_role = self.roles[role]
else:
members, admin_role = [], ""
require(self.has_role(admin_role, self._running_as),
f"AccessControl: AccessControl: account {self._running_as} is missing role '{admin_role}'")
require(
self.has_role(admin_role, self._running_as),
self._error("AccessControlUnauthorizedAccount", self._running_as, admin_role),
)

if user not in members:
members.append(user)
Expand All @@ -337,13 +362,16 @@ def _validate_setattr(self, attr_name, value):
if attr_name in self.set_attr_roles:
require(
self.has_role(self.set_attr_roles[attr_name], self._running_as),
f"AccessControl: AccessControl: account {self._running_as} is missing role "
f"'{self.set_attr_roles[attr_name]}'"
self._error(
"AccessControlUnauthorizedAccount", self._running_as, self.set_attr_roles[attr_name]
),
)


def require(condition, message=None):
if not condition:
if isinstance(message, RevertError):
raise message
raise RevertError(message or "required condition not met")


Expand All @@ -354,14 +382,51 @@ class ERC20Token(AccessControlContract):
symbol = StringField(default="")
decimals = IntField(default=18)
balances = DictField(AddressField(), WadField(), default={})
allowances = DictField(
TupleField((AddressField(), AddressField())),
WadField(),
default={}
)
allowances = DictField(TupleField((AddressField(), AddressField())), WadField(), default={})

_total_supply = WadField(default=ZERO)

def _error(self, error_class, *args) -> RevertError:
if error_class == "ERC20InsufficientBalance":
if self.use_custom_errors:
return RevertCustomError(error_class, args[0], args[1], args[2])
else:
return RevertError("ERC20: transfer amount exceeds balance")
elif error_class == "ERC20InvalidSender":
if self.use_custom_errors:
return RevertCustomError(
error_class, args[0] if args else "0x0000000000000000000000000000000000000000"
)
else:
return RevertError("ERC20: transfer from the zero address")
elif error_class == "ERC20InvalidReceiver":
if self.use_custom_errors:
return RevertCustomError(
error_class, args[0] if args else "0x0000000000000000000000000000000000000000"
)
else:
return RevertError("ERC20: transfer to the zero address")
elif error_class == "ERC20InsufficientAllowance":
if self.use_custom_errors:
return RevertCustomError(error_class, args[0], args[1], args[2])
else:
return RevertError("ERC20: insufficient allowance")
elif error_class == "ERC20InvalidApprover":
if self.use_custom_errors:
return RevertCustomError(
error_class, args[0] if args else "0x0000000000000000000000000000000000000000"
)
else:
return RevertError("ERC20: approve from the zero address")
elif error_class == "ERC20InvalidSpender":
if self.use_custom_errors:
return RevertCustomError(
error_class, args[0] if args else "0x0000000000000000000000000000000000000000"
)
else:
return RevertError("ERC20: approve to the zero address")
return super()._error(error_class, *args)

def __init__(self, **kwargs):
if "initial_supply" in kwargs:
initial_supply = kwargs.pop("initial_supply")
Expand Down Expand Up @@ -410,7 +475,7 @@ def transfer(self, sender, recipient, amount):
def _transfer(self, sender, recipient, amount):
sender, recipient = self._parse_accounts(sender, recipient)
if self.balance_of(sender) < amount:
raise RevertError("ERC20: transfer amount exceeds balance")
raise self._error("ERC20InsufficientBalance", sender, self.balance_of(sender), amount)
elif self.balances[sender] == amount:
del self.balances[sender]
else:
Expand All @@ -425,8 +490,8 @@ def allowance(self, owner, spender):

def _approve(self, owner, spender, amount):
owner, spender = self._parse_accounts(owner, spender)
require(owner is not None, "ERC20: approve from the zero address")
require(spender is not None, "ERC20: approve to the zero address")
require(owner is not None, self._error("ERC20InvalidApprover"))
require(spender is not None, self._error("ERC20InvalidSpender", spender))
if amount == self.ZERO:
try:
del self.allowances[(owner, spender)]
Expand All @@ -447,15 +512,15 @@ def increase_allowance(self, sender, spender, amount):
def decrease_allowance(self, sender, spender, amount):
sender, spender = self._parse_accounts(sender, spender)
allowance = self.allowances.get((sender, spender), self.ZERO)
require(allowance >= amount, "ERC20: decreased allowance below zero")
require(allowance >= amount, self._error("ERC20InsufficientAllowance", spender, allowance, amount))
self._approve(sender, spender, allowance - amount)

@external
def transfer_from(self, spender, sender, recipient, amount):
spender, sender, recipient = self._parse_accounts(spender, sender, recipient)
allowance = self.allowances.get((sender, spender), self.ZERO)
if allowance < amount:
raise RevertError("ERC20: transfer amount exceeds allowance")
raise self._error("ERC20InsufficientAllowance", spender, allowance, amount)
self._transfer(sender, recipient, amount)
self._approve(sender, spender, allowance - amount)
return True
Expand All @@ -464,7 +529,7 @@ def total_supply(self):
return self._total_supply


class ERC721Token(AccessControlContract): # NFT
class ERC721Token(AccessControlContract): # NFT
ZERO = Wad(0)

name = StringField()
Expand All @@ -477,13 +542,55 @@ class ERC721Token(AccessControlContract): # NFT

_token_count = IntField(default=0)

def _error(self, error_class, *args) -> RevertError:
if error_class == "ERC721InvalidOwner":
if self.use_custom_errors:
return RevertCustomError(
error_class, args[0] if args else "0x0000000000000000000000000000000000000000"
)
else:
return RevertError("ERC721: address zero is not a valid owner")
elif error_class == "ERC721NonexistentToken":
if self.use_custom_errors:
return RevertCustomError(error_class, args[0])
else:
return RevertError("ERC721: invalid token ID")
elif error_class == "ERC721IncorrectOwner":
if self.use_custom_errors:
return RevertCustomError(error_class, args[0], args[1], args[2])
else:
return RevertError("ERC721: transfer from incorrect owner")
elif error_class == "ERC721InvalidSender":
if self.use_custom_errors:
return RevertCustomError(
error_class, args[0] if args else "0x0000000000000000000000000000000000000000"
)
else:
return RevertError("ERC721: transfer from incorrect owner")
elif error_class == "ERC721InvalidReceiver":
if self.use_custom_errors:
return RevertCustomError(
error_class, args[0] if args else "0x0000000000000000000000000000000000000000"
)
else:
return RevertError("ERC721: transfer to the zero address")
elif error_class == "ERC721InsufficientApproval":
if self.use_custom_errors:
return RevertCustomError(error_class, args[0], args[1])
else:
return RevertError("ERC721: caller is not token owner nor approved")
return super()._error(error_class, *args)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Podría aplicar acá lo mismo que decía arriba, solo si la opción del mapping es viable.


@external
def mint(self, to, token_id):
if token_id is None:
self._token_count += 1
token_id = self._token_count
if token_id in self.owners:
raise RevertError("ERC721: token already minted")
if self.use_custom_errors:
raise RevertError("ERC721: token already minted")
else:
raise self._error("ERC721InvalidSender")
self.balances[to] = self.balances.get(to, 0) + 1
self.owners[token_id] = to

Expand All @@ -503,7 +610,7 @@ def balance_of(self, address):
@view
def owner_of(self, token_id):
if token_id not in self.owners:
raise RevertError("ERC721: invalid token ID")
raise self._error("ERC721NonexistentToken", token_id)
return self.owners[token_id]

# def token_uri
Expand Down Expand Up @@ -536,23 +643,29 @@ def is_approved_for_all(self, owner, operator):
@external
def transfer_from(self, sender, from_, to, token_id):
owner = self.owners[token_id]
if sender != owner and self.token_approvals.get(token_id, None) != sender and \
sender not in self.operator_approvals.get(owner, []):
raise RevertError("ERC721: caller is not token owner or approved")
if (
sender != owner
and self.token_approvals.get(token_id, None) != sender
and sender not in self.operator_approvals.get(owner, [])
):
raise self._error("ERC721InsufficientApproval", sender, token_id)
return self._transfer(from_, to, token_id)

@external
def safe_transfer_from(self, sender, from_, to, token_id):
owner = self.owners[token_id]
if sender != owner and self.token_approvals.get(token_id, None) != sender and \
sender not in self.operator_approvals.get(owner, []):
raise RevertError("ERC721: caller is not token owner or approved")
if (
sender != owner
and self.token_approvals.get(token_id, None) != sender
and sender not in self.operator_approvals.get(owner, [])
):
raise self._error("ERC721InsufficientApproval", sender, token_id)
# TODO: if `to` is contract, call onERC721Received
return self._transfer(from_, to, token_id)

def _transfer(self, from_, to, token_id):
if self.owners[token_id] != from_:
raise RevertError("ERC721: transfer of token that is not own:")
raise self._error("ERC721InvalidOwner", from_)
if token_id in self.token_approvals:
del self.token_approvals[token_id]
self.balances[from_] -= 1
Expand Down
Loading